diff --git a/tests/cpp/operator/test_cast_transpose.cu b/tests/cpp/operator/test_cast_transpose.cu
index 4a548ddf6f..8c168c76f4 100644
--- a/tests/cpp/operator/test_cast_transpose.cu
+++ b/tests/cpp/operator/test_cast_transpose.cu
@@ -81,7 +81,10 @@ std::vector<std::pair<size_t, size_t>> test_cases = {{2048, 12288},
                                                      {65536, 128},
                                                      {256, 256},
                                                      {120, 2080},
-                                                     {8, 8}};
+                                                     {8, 8},
+                                                     {1, 3221},     // Prime 3221
+                                                     {2333, 1},     // Prime 2333
+                                                     {1481, 677}};  // Primes 1481, 677
 }  // namespace
 
 class CTTestSuite : public ::testing::TestWithParam
diff --git a/transformer_engine/common/transpose/cast_transpose.cu b/transformer_engine/common/transpose/cast_transpose.cu
--- a/transformer_engine/common/transpose/cast_transpose.cu
+++ b/transformer_engine/common/transpose/cast_transpose.cu
 #include
-#include
-#include
-#include "../utils.cuh"
-#include "../common.h"
-
-namespace transformer_engine {
-
-template
-inline __device__ void cast_and_transpose_regs(const IVec (&in)[nvec_out],
-                                               OVec (&out_trans)[nvec_in],
-                                               typename OVec::type *output_cast_tile,
-                                               const size_t current_place,
-                                               const size_t stride,
-                                               CType &max,  // NOLINT(*)
-                                               const CType scale,
-                                               const bool valid_store) {
-  using T = typename OVec::type;
-  using OVecC = Vec;
-#pragma unroll
-  for (unsigned int i = 0; i < nvec_out; ++i) {
-    OVecC out_cast;
-#pragma unroll
-    for (unsigned int j = 0; j < nvec_in; ++j) {
-      const CType tmp = static_cast(in[i].data.elt[j]);
-      const T elt_o = T(scale * tmp);
-
-      out_cast.data.elt[j] = elt_o;
-      out_trans[j].data.elt[i] = elt_o;  // thread tile transpose
-
-      __builtin_assume(max >= 0);
-      max = fmaxf(fabsf(tmp), max);
-    }
-    if (full_tile || valid_store) {
-      out_cast.store_to(output_cast_tile, current_place + stride * i);
-    }
-  }
-}
+#include
-// STUFF TO TUNE
-constexpr unsigned int n_warps_per_tile = 4;
+#include
-constexpr unsigned int max_threads_per_block = 256;
-static_assert(n_warps_per_tile * THREADS_PER_WARP <= max_threads_per_block);
-constexpr unsigned int cast_transpose_num_threads = n_warps_per_tile * THREADS_PER_WARP;
+#include "../common.h"
+#include "../util/rtc.h"
+#include "../util/string.h"
+#include "../utils.cuh"
-template
-__global__ void
-__launch_bounds__(cast_transpose_num_threads)
-cast_transpose_kernel(const IType * const input,
-                      const CType * const noop,
-                      OType * const output_c,
-                      OType * const output_t,
-                      const CType * const scale_ptr,
-                      CType * const amax,
-                      const size_t row_length,
-                      const size_t num_rows,
-                      const size_t num_tiles) {
-  if (noop != nullptr && noop[0] == 1.0f) return;
+namespace transformer_engine {
-  using IVec = Vec;
-  using OVec = Vec;
-
-  extern __shared__ char scratch[];
-
-  const int warp_id = threadIdx.x / THREADS_PER_WARP;
-  const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP;
-  const size_t num_tiles_x = row_length / (nvec_in * THREADS_PER_WARP);
-  const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) +
-                         warp_id / n_warps_per_tile;
-  if (tile_id >= num_tiles) return;
-  const size_t tile_id_x = tile_id % num_tiles_x;
-  const size_t tile_id_y = tile_id / num_tiles_x;
-
-  const IType * const my_input_tile = input + (tile_id_x * nvec_in +
-                                               tile_id_y * row_length * nvec_out) *
-                                              THREADS_PER_WARP;
-  OType * const my_output_c_tile = output_c + (tile_id_x * nvec_in +
-                                               tile_id_y * row_length * nvec_out) *
-                                              THREADS_PER_WARP;
-  OType * const my_output_t_tile = output_t + (tile_id_y * nvec_out +
-                                               tile_id_x * num_rows * nvec_in) *
-                                              THREADS_PER_WARP;
-  OVec * const my_scratch = reinterpret_cast(scratch) +
-                            (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) *
-                            (THREADS_PER_WARP + 1);
-
-  IVec in[2][nvec_out];
-  const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile;
-  constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile;
-  OVec
out_space[n_iterations][nvec_in]; - - const size_t stride = row_length / nvec_in; - const size_t output_stride = num_rows / nvec_out; - size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; - unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - - warp_id_in_tile * n_iterations) % - THREADS_PER_WARP; - CType max = 0; - const CType scale = scale_ptr != nullptr ? *scale_ptr : 1; -#pragma unroll - for (unsigned int i = 0; i < nvec_out; ++i) { - in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i); - } -#pragma unroll - for (unsigned int i = 0; i < n_iterations; ++i) { - const size_t current_place = current_stride + my_place; - const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - const unsigned int current_in = (i + 1) % 2; - if (i < n_iterations - 1) { -#pragma unroll - for (unsigned int j = 0; j < nvec_out; ++j) { - in[current_in][j].load_from(my_input_tile, - current_stride + my_place_in + stride * (nvec_out + j)); - } +namespace { + +// String with RTC kernel implementation +#include "string_code_transpose_rtc_cast_transpose_cu.h" + +// Hard-coded kernel parameters +using CType = float; +constexpr size_t warps_per_tile = 4; +constexpr size_t block_size = THREADS_PER_WARP * warps_per_tile; + +/* Performance heuristics for optimized kernel parameters */ +struct KernelConfig { + /** Vector load size */ + size_t load_size = 0; + /** Vector store size to transposed output */ + size_t store_size = 0; + + /* Whether config is valid */ + bool valid = false; + /* Number of CUDA blocks */ + size_t num_blocks = 0; + + /* Number of active SMs */ + size_t active_sm_count = 0; + /* Elements per L1 cache load */ + size_t elements_per_load = 0; + /* Elements per L1 cache store to cast output*/ + size_t elements_per_store_c = 0; + /* Elements per L1 cache store to transposed output */ + size_t elements_per_store_t = 0; + + KernelConfig(size_t row_length, + size_t num_rows, + size_t itype_size, + size_t otype_size, + size_t load_size_, + size_t store_size_) + : load_size{load_size_} + , store_size{store_size_} { + // Check that tiles are correctly aligned + constexpr size_t cache_line_size = 128; + if (load_size % itype_size != 0 + || store_size % otype_size != 0 + || cache_line_size % itype_size != 0 + || cache_line_size % otype_size != 0) { + return; } - OVec out_trans[nvec_in]; // NOLINT(*) - cast_and_transpose_regs(in[current_in ^ 1], out_trans, my_output_c_tile, - current_place, stride, max, scale, true); -#pragma unroll - for (unsigned int j = 0; j < nvec_in; ++j) { - out_space[i][j].data.vec = out_trans[j].data.vec; + const size_t row_tile_elements = load_size * THREADS_PER_WARP / itype_size; + const size_t col_tile_elements = store_size * THREADS_PER_WARP / otype_size; + valid = (row_length % row_tile_elements == 0 + && num_rows % col_tile_elements == 0); + if (!valid) { + return; } - my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - current_stride += nvec_out * stride; - } - for (unsigned int i = 0; i < nvec_in; ++i) { -#pragma unroll - for (unsigned int j = 0; j < n_iterations; ++j) { - my_scratch[(my_id_in_warp + THREADS_PER_WARP - - j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i]; - } - __syncthreads(); - my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % - THREADS_PER_WARP; - current_stride = i * output_stride + - warp_id_in_tile * n_iterations * output_stride * nvec_in; - for (unsigned int j = 0; j < n_iterations; ++j) { - my_scratch[j + 
warp_id_in_tile * n_iterations].store_to(my_output_t_tile, - current_stride + my_place); - my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - current_stride += output_stride * nvec_in; - } - __syncthreads(); + // Number of CUDA blocks + num_blocks = (row_length / row_tile_elements) * (num_rows / col_tile_elements); + + // Parameters for performance model + constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs + active_sm_count = std::min(DIVUP(num_blocks * warps_per_tile, warps_per_sm), + static_cast(cuda::sm_count())); + elements_per_load = (std::min(cache_line_size, row_tile_elements * itype_size) + / itype_size); + elements_per_store_c = (std::min(cache_line_size, row_tile_elements * otype_size) + / otype_size); + elements_per_store_t = (std::min(cache_line_size, col_tile_elements * otype_size) + / otype_size); } - /* warp tile amax reduce*/ - max = reduce_max(max, warp_id); - - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - if (amax != nullptr) atomicMaxFloat(amax, max); + /* Compare by estimated cost */ + bool operator<(const KernelConfig &other) const { + if (this->valid && other.valid) { + // cost ~ (1/elements_per_load + // + 1/elements_per_store_c + // + 1/elements_per_store_t) / active_sms + // Note: Integer arithmetic ensures stable ordering + const auto &l1 = this->elements_per_load; + const auto &sc1 = this->elements_per_store_c; + const auto &st1 = this->elements_per_store_t; + const auto &p1 = this->active_sm_count; + const auto &l2 = other.elements_per_load; + const auto &sc2 = other.elements_per_store_c; + const auto &st2 = other.elements_per_store_t; + const auto &p2 = other.active_sm_count; + const auto scale = l1 * sc1 * st1 * p1 * l2 * sc2 * st2 * p2; + const auto cost1 = (scale/l1 + scale/sc1 + scale/st1) / p1; + const auto cost2 = (scale/l2 + scale/sc2 + scale/st2) / p2; + return cost1 < cost2; + } else { + return this->valid && !other.valid; + } } -} +}; -template +template __global__ void -__launch_bounds__(cast_transpose_num_threads) -cast_transpose_kernel_notaligned(const IType * const input, - const CType * const noop, - OType * const output_c, - OType * const output_t, - const CType * const scale_ptr, - CType * const amax, - const size_t row_length, - const size_t num_rows, - const size_t num_tiles) { +__launch_bounds__(block_size) +cast_transpose_general_kernel(const IType * __restrict__ const input, + const CType * __restrict__ const noop, + OType * __restrict__ const output_c, + OType * __restrict__ const output_t, + const CType * __restrict__ const scale_ptr, + CType * __restrict__ const amax_ptr, + const size_t row_length, + const size_t num_rows) { if (noop != nullptr && noop[0] == 1.0f) return; + // Vectorized load/store sizes + constexpr size_t nvec_in = load_size / sizeof(IType); + constexpr size_t nvec_out = store_size / sizeof(OType); using IVec = Vec; - using OVec = Vec; - - extern __shared__ char scratch[]; - - const int warp_id = threadIdx.x / THREADS_PER_WARP; - const int my_id_in_warp = threadIdx.x % THREADS_PER_WARP; - const size_t num_tiles_x = (row_length + nvec_in * THREADS_PER_WARP - 1) / - (nvec_in * THREADS_PER_WARP); - const size_t tile_id = blockIdx.x * blockDim.x / (THREADS_PER_WARP * n_warps_per_tile) + - warp_id / n_warps_per_tile; - if (tile_id >= num_tiles) return; - const size_t tile_id_x = tile_id % num_tiles_x; - const size_t tile_id_y = tile_id / num_tiles_x; - - const IType * const my_input_tile = input + (tile_id_x * nvec_in + - tile_id_y * row_length * nvec_out) * - 
THREADS_PER_WARP; - OType * const my_output_c_tile = output_c + (tile_id_x * nvec_in + - tile_id_y * row_length * nvec_out) * - THREADS_PER_WARP; - OType * const my_output_t_tile = output_t + (tile_id_y * nvec_out + - tile_id_x * num_rows * nvec_in) * - THREADS_PER_WARP; - const size_t stride = row_length / nvec_in; - const size_t output_stride = num_rows / nvec_out; - const size_t row_length_rest = stride - tile_id_x * THREADS_PER_WARP; - const size_t row_height_rest = output_stride - tile_id_y * THREADS_PER_WARP; - const unsigned int tile_length = row_length_rest > THREADS_PER_WARP ? THREADS_PER_WARP - : row_length_rest; - const unsigned int tile_height = row_height_rest > THREADS_PER_WARP ? THREADS_PER_WARP - : row_height_rest; - - OVec * const my_scratch = reinterpret_cast(scratch) + - (my_id_in_warp + warp_id / n_warps_per_tile * THREADS_PER_WARP) * - (THREADS_PER_WARP + 1); - - IVec in[2][nvec_out]; - const unsigned int warp_id_in_tile = warp_id % n_warps_per_tile; - constexpr unsigned int n_iterations = THREADS_PER_WARP / n_warps_per_tile; - OVec out_space[n_iterations][nvec_in]; - - size_t current_stride = warp_id_in_tile * n_iterations * nvec_out * stride; - unsigned int my_place = (my_id_in_warp + THREADS_PER_WARP - - warp_id_in_tile * n_iterations) % - THREADS_PER_WARP; - CType max = 0; - const CType scale = scale_ptr != nullptr ? *scale_ptr : 1; - { - const bool valid_load = my_place < tile_length && - warp_id_in_tile * n_iterations < tile_height; -#pragma unroll - for (unsigned int i = 0; i < nvec_out; ++i) { - if (valid_load) { - in[0][i].load_from(my_input_tile, current_stride + my_place + stride * i); - } else { - in[0][i].clear(); - } - } - } -#pragma unroll - for (unsigned int i = 0; i < n_iterations; ++i) { - const size_t current_place = current_stride + my_place; - const unsigned int my_place_in = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - const unsigned int current_in = (i + 1) % 2; - if (i < n_iterations - 1) { - const bool valid_load = my_place_in < tile_length && - warp_id_in_tile * n_iterations + i + 1 < tile_height; -#pragma unroll - for (unsigned int j = 0; j < nvec_out; ++j) { - if (valid_load) { - in[current_in][j].load_from(my_input_tile, - current_stride + my_place_in + stride * (nvec_out + j)); - } else { - in[current_in][j].clear(); + using OVecT = Vec; + + // Thread indices + // Note: Block is interpreted as a warp_size x num_warps grid + constexpr size_t bdimx = THREADS_PER_WARP; + constexpr size_t bdimy = warps_per_tile; + const size_t tid = threadIdx.x; + const size_t tidx = tid % bdimx; + const size_t tidy = tid / bdimx; + const size_t bid = blockIdx.x; + + // Input tensors are divided into tiles + // Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles + constexpr size_t tile_dim_m = THREADS_PER_WARP * nvec_out; + constexpr size_t tile_dim_n = THREADS_PER_WARP * nvec_in; + + // Position of tile within tensor + const size_t num_tiles_m = (num_rows + tile_dim_m - 1) / tile_dim_m; + const size_t tile_id_m = bid % num_tiles_m; + const size_t tile_id_n = bid / num_tiles_m; + const size_t tile_row = tile_id_m * tile_dim_m; + const size_t tile_col = tile_id_n * tile_dim_n; + + // Number of nvec_out x nvec_in subtiles for each thread to + // load/store + constexpr size_t num_iterations = THREADS_PER_WARP / warps_per_tile; + + // FP8 factors + const CType scale = scale_ptr == nullptr ? 
1 : *scale_ptr; + CType amax = 0; + + // Load input and store to registers + // Note: Each thread loads num_iterations subtiles, computes amax, + // casts type, and transposes in registers. + OVecT local_output_t[nvec_in][num_iterations]; + #pragma unroll + for (size_t iter = 0; iter < num_iterations; ++iter) { + const size_t i1 = tidy + iter * bdimy; + const size_t j1 = tidx; + #pragma unroll + for (size_t i2 = 0; i2 < nvec_out; ++i2) { + const size_t row = tile_row + i1 * nvec_out + i2; + const size_t col = tile_col + j1 * nvec_in; + if (row < num_rows) { + #pragma unroll + for (size_t j2 = 0; j2 < nvec_in; ++j2) { + if (col + j2 < row_length) { + const CType in = input[row * row_length + col + j2]; + const OType out = OType(in * scale); + __builtin_assume(amax >= 0); + amax = fmaxf(fabsf(in), amax); + output_c[row * row_length + col + j2] = out; + local_output_t[j2][iter].data.elt[i2] = out; } } + } } - OVec out_trans[nvec_in]; // NOLINT(*) - const bool valid_store = my_place < tile_length && - warp_id_in_tile * n_iterations + i < tile_height; - cast_and_transpose_regs(in[current_in ^ 1], out_trans, my_output_c_tile, - current_place, stride, max, scale, valid_store); -#pragma unroll - for (unsigned int j = 0; j < nvec_in; ++j) { - out_space[i][j].data.vec = out_trans[j].data.vec; - } - my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - current_stride += nvec_out * stride; } - for (unsigned int i = 0; i < nvec_in; ++i) { -#pragma unroll - for (unsigned int j = 0; j < n_iterations; ++j) { - my_scratch[(my_id_in_warp + THREADS_PER_WARP - - j - warp_id_in_tile * n_iterations) % THREADS_PER_WARP] = out_space[j][i]; + // Copy transposed output from registers to global memory + __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP+1]; + #pragma unroll + for (size_t j2 = 0; j2 < nvec_in; ++j2) { + #pragma unroll + for (size_t iter = 0; iter < num_iterations; ++iter) { + const size_t i1 = tidy + iter * bdimy; + const size_t j1 = tidx; + shared_output_t[j1][i1] = local_output_t[j2][iter]; } __syncthreads(); - my_place = (my_id_in_warp + THREADS_PER_WARP - warp_id_in_tile * n_iterations) % - THREADS_PER_WARP; - current_stride = i * output_stride + - warp_id_in_tile * n_iterations * output_stride * nvec_in; - for (unsigned int j = 0; warp_id_in_tile * n_iterations + j < tile_length; ++j) { - const bool valid_store = my_place < tile_height; - if (valid_store) { - my_scratch[j + warp_id_in_tile * n_iterations].store_to(my_output_t_tile, - current_stride + my_place); + #pragma unroll + for (size_t iter = 0; iter < num_iterations; ++iter) { + const size_t i1 = tidx; + const size_t j1 = tidy + iter * bdimy; + const size_t row = tile_row + i1 * nvec_out; + const size_t col = tile_col + j1 * nvec_in + j2; + if (col < row_length) { + #pragma unroll + for (size_t i2 = 0; i2 < nvec_out; ++i2) { + if (row + i2 < num_rows) { + output_t[col * num_rows + row + i2] = shared_output_t[j1][i1].data.elt[i2]; + } + } } - my_place = (my_place + THREADS_PER_WARP - 1) % THREADS_PER_WARP; - current_stride += output_stride * nvec_in; } __syncthreads(); } - /* warp tile amax reduce*/ - max = reduce_max(max, warp_id); - - if (threadIdx.x == 0) { - static_assert(std::is_same::value); - if (amax != nullptr) atomicMaxFloat(amax, max); + // Reduce amax over block + if (amax_ptr != nullptr) { + amax = reduce_max(amax, tidy); + if (threadIdx.x == 0) { + atomicMaxFloat(amax_ptr, amax); + } } } +} // namespace + void cast_transpose(const Tensor &input, const Tensor &noop, - Tensor *cast_output, - Tensor 
*transposed_output, + Tensor *cast_output_, + Tensor *transposed_output_, cudaStream_t stream) { - CheckInputTensor(input, "cast_transpose_input"); - CheckOutputTensor(*cast_output, "cast_output"); - CheckOutputTensor(*transposed_output, "transposed_output"); - - // Number of elements in tensor - auto numel = [] (const Tensor &tensor) -> size_t { - size_t acc = 1; - for (const auto& dim : tensor.data.shape) { - acc *= dim; - } - return acc; - }; + Tensor &cast_output = *cast_output_; + Tensor &transposed_output = *transposed_output_; + // Check no-op flag if (noop.data.dptr != nullptr) { - NVTE_CHECK(numel(noop) == 1, - "Expected 1 element, ", - "but found ", numel(noop), "."); + size_t numel = 1; + for (const auto& dim : noop.data.shape) { + numel *= dim; + } + NVTE_CHECK(numel == 1, "Expected 1 element, but found ", numel, "."); NVTE_CHECK(noop.data.dtype == DType::kFloat32); NVTE_CHECK(noop.data.dptr != nullptr); } + + // Check tensor dims + CheckInputTensor(input, "cast_transpose_input"); + CheckOutputTensor(cast_output, "cast_output"); + CheckOutputTensor(transposed_output, "transposed_output"); NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); - NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); - NVTE_CHECK(transposed_output->data.shape.size() == 2, "T output must have 2 dimensions."); - NVTE_CHECK(input.data.shape == cast_output->data.shape, - "Input and C output must have the same shape."); + NVTE_CHECK(cast_output.data.shape.size() == 2, "Cast output must have 2 dimensions."); + NVTE_CHECK(transposed_output.data.shape.size() == 2, + "Transposed output must have 2 dimensions."); const size_t row_length = input.data.shape[1]; const size_t num_rows = input.data.shape[0]; - - NVTE_CHECK(transposed_output->data.shape[0] == row_length, "Wrong dimension of T output."); - NVTE_CHECK(transposed_output->data.shape[1] == num_rows, "Wrong dimension of T output."); - - NVTE_CHECK(cast_output->data.dtype == transposed_output->data.dtype, - "C and T outputs need to have the same type."); - NVTE_CHECK(cast_output->amax.dptr == transposed_output->amax.dptr, - "C and T outputs need to share amax tensor."); - NVTE_CHECK(cast_output->scale.dptr == transposed_output->scale.dptr, - "C and T outputs need to share scale tensor."); - -// Launch specific cast-transpose kernel -#define LAUNCH_KERNEL(kernel, nvec_in, nvec_out, n_tiles, n_blocks, InputType, OutputType) \ - do { \ - cudaFuncSetAttribute(kernel, \ - cudaFuncAttributePreferredSharedMemoryCarveout, \ - 100); \ - kernel \ - <<), \ - stream>>>( \ - reinterpret_cast(input.data.dptr), \ - reinterpret_cast(noop.data.dptr), \ - reinterpret_cast(cast_output->data.dptr), \ - reinterpret_cast(transposed_output->data.dptr), \ - reinterpret_cast(cast_output->scale.dptr), \ - reinterpret_cast(cast_output->amax.dptr), \ - row_length, num_rows, n_tiles); \ - } while (false) - -// Launch cast-transpose kernel for given vector sizes -#define LAUNCH_KERNEL_VEC_SIZES(load_size, store_size, InputType, OutputType) \ - do { \ - constexpr int nvec_in = load_size / sizeof(InputType); \ - constexpr int nvec_out = store_size / sizeof(OutputType); \ - \ - NVTE_CHECK(row_length % nvec_in == 0, "Unsupported shape."); \ - NVTE_CHECK(num_rows % nvec_out == 0, "Unsupported shape."); \ - \ - const size_t n_tiles = get_n_tiles(load_size, store_size); \ - const size_t n_blocks = get_n_blocks(n_tiles); \ - \ - const bool full_tile = row_length % (nvec_in * THREADS_PER_WARP) == 0 && \ - num_rows % (nvec_out * 
THREADS_PER_WARP) == 0; \ - \ - if (full_tile) { \ - LAUNCH_KERNEL(cast_transpose_kernel, \ - nvec_in, nvec_out, n_tiles, n_blocks, \ - InputType, OutputType); \ - } else { \ - LAUNCH_KERNEL(cast_transpose_kernel_notaligned, \ - nvec_in, nvec_out, n_tiles, n_blocks, \ - InputType, OutputType); \ - } \ - } while (false) + NVTE_CHECK(cast_output.data.shape[0] == num_rows, "Wrong dimension of cast output."); + NVTE_CHECK(cast_output.data.shape[1] == row_length, "Wrong dimension of cast output."); + NVTE_CHECK(transposed_output.data.shape[0] == row_length, + "Wrong dimension of transposed output."); + NVTE_CHECK(transposed_output.data.shape[1] == num_rows, + "Wrong dimension of transposed output."); + + // Check tensor pointers + NVTE_CHECK(input.data.dptr != nullptr, "Input is not allocated."); + NVTE_CHECK(cast_output.data.dptr != nullptr, "Cast output is not allocated."); + NVTE_CHECK(transposed_output.data.dptr != nullptr, "Transposed output is not allocated."); + NVTE_CHECK(cast_output.data.dtype == transposed_output.data.dtype, + "Cast and transposed output types must match."); + NVTE_CHECK(cast_output.amax.dptr == transposed_output.amax.dptr, + "Cast and transposed outputs need to share amax tensor."); + NVTE_CHECK(cast_output.scale.dptr == transposed_output.scale.dptr, + "Cast and transposed outputs need to share scale tensor."); TRANSFORMER_ENGINE_TYPE_SWITCH_INPUT(input.data.dtype, InputType, - TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output->data.dtype, OutputType, - - // Estimate number of SMs - // Note: H100 has 132 SMs, A100 has 108 SMs. - // Note: Directly querying number of SMs with cudaGetDeviceProperties is - // slow (>1 ms). Consider querying once and caching. - const int n_sms = 128; - - // Helper functions to get kernel configuration - auto get_n_tiles = [=] (size_t load_size, size_t store_size) -> int { - constexpr size_t threads_per_warp = static_cast(THREADS_PER_WARP); - size_t nvec_in = load_size / sizeof(InputType); - size_t nvec_out = store_size / sizeof(OutputType); - size_t n_tiles = DIVUP(row_length, nvec_in * threads_per_warp) * - DIVUP(num_rows, nvec_out * threads_per_warp); - return n_tiles; - }; - auto get_n_blocks = [=] (size_t n_tiles) -> int { - size_t n_warps_per_block = cast_transpose_num_threads / THREADS_PER_WARP; - size_t n_blocks = DIVUP(n_tiles * n_warps_per_tile, n_warps_per_block); - return n_blocks; - }; - - // Estimate optimal vector sizes and run - // Note: Consider reducing to 2B or 1B loads/stores for - // sufficiently small matrices. Need to consider whether reduced - // cache efficiency is worth increased SM utilization. Also need - // to keep in mind whether datatype can fit. 
- const size_t estimated_n_tiles = get_n_tiles(8, 8); - const size_t estimated_n_blocks = get_n_blocks(estimated_n_tiles); - if (estimated_n_blocks >= n_sms) { - LAUNCH_KERNEL_VEC_SIZES(8, 8, InputType, OutputType); - } else { - LAUNCH_KERNEL_VEC_SIZES(4, 4, InputType, OutputType); + TRANSFORMER_ENGINE_TYPE_SWITCH_OUTPUT(cast_output.data.dtype, OutputType, + constexpr const char *itype_name = TypeInfo::name; + constexpr const char *otype_name = TypeInfo::name; + constexpr size_t itype_size = sizeof(InputType); + constexpr size_t otype_size = sizeof(OutputType); + + // Choose between runtime-compiled or statically-compiled kernel + const bool aligned = (row_length % THREADS_PER_WARP == 0 + && num_rows % THREADS_PER_WARP == 0); + if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel + // Pick kernel config + std::vector kernel_configs; + kernel_configs.reserve(16); + auto add_config = [&](size_t load_size, size_t store_size) { + kernel_configs.emplace_back(row_length, num_rows, + itype_size, otype_size, + load_size, store_size); + }; + add_config(8, 8); + add_config(4, 8); add_config(8, 4); + add_config(4, 4); + add_config(2, 8); add_config(8, 2); + add_config(2, 4); add_config(4, 2); + add_config(2, 2); + add_config(1, 8); add_config(8, 1); + add_config(1, 4); add_config(4, 1); + add_config(1, 2); add_config(2, 1); + add_config(1, 1); + const auto &kernel_config = *std::min_element(kernel_configs.begin(), + kernel_configs.end()); + NVTE_CHECK(kernel_config.valid, "invalid kernel config"); + const size_t load_size = kernel_config.load_size; + const size_t store_size = kernel_config.store_size; + const size_t num_blocks = kernel_config.num_blocks; + + // Compile NVRTC kernel if needed and launch + auto& rtc_manager = rtc::KernelManager::instance(); + const std::string kernel_label = concat_strings("cast_transpose" + ",itype=", itype_name, + ",otype=", otype_name, + ",load_size=", load_size, + ",store_size=", store_size); + if (!rtc_manager.is_compiled(kernel_label)) { + std::string code = string_code_transpose_rtc_cast_transpose_cu; + code = regex_replace(code, "__ITYPE__", itype_name); + code = regex_replace(code, "__OTYPE__", otype_name); + code = regex_replace(code, "__LOAD_SIZE__", load_size); + code = regex_replace(code, "__STORE_SIZE__", store_size); + code = regex_replace(code, "__WARPS_PER_TILE__", warps_per_tile); + code = regex_replace(code, "__BLOCK_SIZE__", block_size); + rtc_manager.compile(kernel_label, + "cast_transpose_optimized_kernel", + code, + "transformer_engine/common/transpose/rtc/cast_transpose.cu"); + } + rtc_manager.launch(kernel_label, + num_blocks, block_size, 0, stream, + static_cast(input.data.dptr), + reinterpret_cast(noop.data.dptr), + static_cast(cast_output.data.dptr), + static_cast(transposed_output.data.dptr), + static_cast(cast_output.scale.dptr), + static_cast(cast_output.amax.dptr), + row_length, num_rows); + } else { // Statically-compiled general kernel + constexpr size_t load_size = 4; + constexpr size_t store_size = 4; + constexpr size_t row_tile_size = load_size / itype_size * THREADS_PER_WARP; + constexpr size_t col_tile_size = store_size / otype_size * THREADS_PER_WARP; + const int num_blocks = (DIVUP(row_length, row_tile_size) + * DIVUP(num_rows, col_tile_size)); + cast_transpose_general_kernel + <<>>( + static_cast(input.data.dptr), + reinterpret_cast(noop.data.dptr), + static_cast(cast_output.data.dptr), + static_cast(transposed_output.data.dptr), + static_cast(cast_output.scale.dptr), + static_cast(cast_output.amax.dptr), + 
            row_length, num_rows);
      }
-    );  // NOLINT(*)
     );  // NOLINT(*)
-
-#undef LAUNCH_KERNEL
-#undef LAUNCH_KERNEL_VEC_SIZES
 }
 
 }  // namespace transformer_engine
diff --git a/transformer_engine/common/transpose/rtc/cast_transpose.cu b/transformer_engine/common/transpose/rtc/cast_transpose.cu
new file mode 100644
index 0000000000..d503581718
--- /dev/null
+++ b/transformer_engine/common/transpose/rtc/cast_transpose.cu
@@ -0,0 +1,129 @@
+/*************************************************************************
+ * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ *
+ * See LICENSE for license information.
+ ************************************************************************/
+
+#include "utils.cuh"
+
+using namespace transformer_engine;
+
+namespace {
+
+// Parameters
+using CType = float;
+using IType = __ITYPE__;
+using OType = __OTYPE__;
+constexpr size_t load_size = __LOAD_SIZE__;
+constexpr size_t store_size = __STORE_SIZE__;
+constexpr size_t warps_per_tile = __WARPS_PER_TILE__;
+constexpr size_t block_size = __BLOCK_SIZE__;
+
+}  // namespace
+
+__global__ void
+__launch_bounds__(block_size)
+cast_transpose_optimized_kernel(const IType * __restrict__ const input,
+                                const CType * __restrict__ const noop,
+                                OType * __restrict__ const output_c,
+                                OType * __restrict__ const output_t,
+                                const CType * __restrict__ const scale_ptr,
+                                CType * __restrict__ const amax_ptr,
+                                const size_t row_length,
+                                const size_t num_rows) {
+  if (noop != nullptr && noop[0] == 1.0f) return;
+
+  // Vectorized load/store sizes
+  constexpr size_t nvec_in = load_size / sizeof(IType);
+  constexpr size_t nvec_out = store_size / sizeof(OType);
+  using IVec = Vec<IType, nvec_in>;
+  using OVecC = Vec<OType, nvec_in>;
+  using OVecT = Vec<OType, nvec_out>;
+
+  // Thread indices
+  // Note: Block is interpreted as a warp_size x num_warps grid
+  constexpr size_t bdimx = THREADS_PER_WARP;
+  constexpr size_t bdimy = warps_per_tile;
+  const size_t tid = threadIdx.x;
+  const size_t tidx = tid % bdimx;
+  const size_t tidy = tid / bdimx;
+  const size_t bid = blockIdx.x;
+
+  // Input tensors are divided into tiles
+  // Note: Each tile is a warp_size x warp_size grid of nvec_out x nvec_in subtiles
+  constexpr size_t tile_dim_m = THREADS_PER_WARP * nvec_out;
+  constexpr size_t tile_dim_n = THREADS_PER_WARP * nvec_in;
+
+  // Position of tile within tensor
+  const size_t num_tiles_m = num_rows / tile_dim_m;
+  const size_t tile_id_m = bid % num_tiles_m;
+  const size_t tile_id_n = bid / num_tiles_m;
+  const size_t tile_row = tile_id_m * tile_dim_m;
+  const size_t tile_col = tile_id_n * tile_dim_n;
+
+  // Number of nvec_out x nvec_in subtiles for each thread to
+  // load/store
+  constexpr size_t num_iterations = THREADS_PER_WARP / warps_per_tile;
+
+  // FP8 factors
+  const CType scale = scale_ptr == nullptr ? 1 : *scale_ptr;
+  CType amax = 0;
+
+  // Load input to registers and transpose
+  // Note: Each thread loads num_iterations subtiles, computes amax,
+  // casts type, and transposes in registers.
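+  // The register tile local_output_t below is indexed as [element within a
+  // subtile row][iteration]; each OVecT packs the nvec_out values that land
+  // contiguously in the transposed output. As an illustration (not part of the
+  // original patch), with 2-byte input/output types and load_size = store_size = 8,
+  // nvec_in = nvec_out = 4, so each thread handles num_iterations 4 x 4 subtiles.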
+  OVecT local_output_t[nvec_in][num_iterations];
+  #pragma unroll
+  for (size_t iter = 0; iter < num_iterations; ++iter) {
+    const size_t i1 = tidy + iter * bdimy;
+    const size_t j1 = tidx;
+    #pragma unroll
+    for (size_t i2 = 0; i2 < nvec_out; ++i2) {
+      const size_t row = tile_row + i1 * nvec_out + i2;
+      const size_t col = tile_col + j1 * nvec_in;
+      IVec local_input;
+      OVecC local_output_c;
+      local_input.load_from(&input[row * row_length + col]);
+      #pragma unroll
+      for (size_t j2 = 0; j2 < nvec_in; ++j2) {
+        const CType in = static_cast<CType>(local_input.data.elt[j2]);
+        const OType out = OType(in * scale);
+        __builtin_assume(amax >= 0);
+        amax = fmaxf(fabsf(in), amax);
+        local_output_c.data.elt[j2] = out;
+        local_output_t[j2][iter].data.elt[i2] = out;
+      }
+      local_output_c.store_to(&output_c[row * row_length + col]);
+    }
+  }
+
+  // Copy from registers to shared memory to global memory
+  __shared__ OVecT shared_output_t[THREADS_PER_WARP][THREADS_PER_WARP+1];
+  #pragma unroll
+  for (size_t j2 = 0; j2 < nvec_in; ++j2) {
+    #pragma unroll
+    for (size_t iter = 0; iter < num_iterations; ++iter) {
+      const size_t i1 = tidy + iter * bdimy;
+      const size_t j1 = tidx;
+      shared_output_t[j1][i1] = local_output_t[j2][iter];
+    }
+    __syncthreads();
+    #pragma unroll
+    for (size_t iter = 0; iter < num_iterations; ++iter) {
+      const size_t i1 = tidx;
+      const size_t j1 = tidy + iter * bdimy;
+      const size_t row = tile_row + i1 * nvec_out;
+      const size_t col = tile_col + j1 * nvec_in + j2;
+      shared_output_t[j1][i1].store_to(&output_t[col * num_rows + row]);
+    }
+    __syncthreads();
+  }
+
+  // Reduce amax over block
+  if (amax_ptr != nullptr) {
+    amax = reduce_max<warps_per_tile>(amax, tidy);
+    if (threadIdx.x == 0) {
+      atomicMaxFloat(amax_ptr, amax);
+    }
+  }
+}
diff --git a/transformer_engine/common/transpose/transpose.cu b/transformer_engine/common/transpose/transpose.cu
index 3ab83b944b..c0a1a7fbcf 100644
--- a/transformer_engine/common/transpose/transpose.cu
+++ b/transformer_engine/common/transpose/transpose.cu
@@ -6,13 +6,15 @@
 #include
 #include
+
+#include
+
 #include
-#include
-#include
+
 #include "../common.h"
-#include "../utils.cuh"
-#include "../util/string.h"
 #include "../util/rtc.h"
+#include "../util/string.h"
+#include "../utils.cuh"
 
 namespace transformer_engine {
 
@@ -25,7 +27,80 @@ namespace {
 
 constexpr size_t warps_per_tile = 4;
 constexpr size_t block_size = THREADS_PER_WARP * warps_per_tile;
 
-}  // namespace
+/* Performance heuristics for optimized kernel parameters */
+struct KernelConfig {
+  /* Vector load size */
+  size_t load_size;
+  /* Vector store size */
+  size_t store_size;
+
+  /* Whether config is valid */
+  bool valid = false;
+  /* Number of CUDA blocks */
+  size_t num_blocks = 0;
+
+  /* Number of active SMs */
+  size_t active_sm_count = 0;
+  /* Elements per L1 cache load */
+  size_t elements_per_load = 0;
+  /* Elements per L1 cache store */
+  size_t elements_per_store = 0;
+
+  KernelConfig(size_t row_length,
+               size_t num_rows,
+               size_t type_size,
+               size_t load_size_,
+               size_t store_size_)
+    : load_size{load_size_}
+    , store_size{store_size_} {
+    // Check that tiles are correctly aligned
+    constexpr size_t cache_line_size = 128;
+    if (load_size % type_size != 0
+        || store_size % type_size != 0
+        || cache_line_size % type_size != 0) {
+      return;
+    }
+    const size_t row_tile_elements = load_size * THREADS_PER_WARP / type_size;
+    const size_t col_tile_elements = store_size * THREADS_PER_WARP / type_size;
+    valid = (row_length % row_tile_elements == 0
+             && num_rows % col_tile_elements == 0);
+    if (!valid) {
+
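+      // Tile shape does not evenly divide the tensor, so this config is rejected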
return; + } + + // Number of CUDA blocks + num_blocks = (row_length / row_tile_elements) * (num_rows / col_tile_elements); + + // Parameters for performance model + constexpr size_t warps_per_sm = 16; // Rough estimate for saturated SMs + active_sm_count = std::min(DIVUP(num_blocks * warps_per_tile, warps_per_sm), + static_cast(cuda::sm_count())); + elements_per_load = (std::min(cache_line_size, row_tile_elements * type_size) + / type_size); + elements_per_store = (std::min(cache_line_size, col_tile_elements * type_size) + / type_size); + } + + /* Compare by estimated cost */ + bool operator<(const KernelConfig &other) const { + if (this->valid && other.valid) { + // cost ~ (1/elements_per_load + 1/elements_per_store) / active_sms + // Note: Integer arithmetic ensures stable ordering + const auto &l1 = this->elements_per_load; + const auto &s1 = this->elements_per_store; + const auto &p1 = this->active_sm_count; + const auto &l2 = other.elements_per_load; + const auto &s2 = other.elements_per_store; + const auto &p2 = other.active_sm_count; + const auto scale = l1 * s1 * p1 * l2 * s2 * p2; + const auto cost1 = (scale/l1 + scale/s1) / p1; + const auto cost2 = (scale/l2 + scale/s2) / p2; + return cost1 < cost2; + } else { + return this->valid && !other.valid; + } + } +}; template __global__ void @@ -127,6 +202,8 @@ transpose_general_kernel(const Type * __restrict__ const input, } } +} // namespace + void transpose(const Tensor &input, const Tensor &noop, Tensor *output_, @@ -170,82 +247,36 @@ void transpose(const Tensor &input, const bool aligned = (row_length % THREADS_PER_WARP == 0 && num_rows % THREADS_PER_WARP == 0); if (aligned && rtc::is_enabled()) { // Runtime-compiled tuned kernel - // Determine kernel config - size_t load_size = 8; - size_t store_size = 8; - auto is_tile_aligned = [&](size_t load_size_, size_t store_size_) -> bool { - return (row_length % (load_size / type_size * THREADS_PER_WARP) == 0 - && num_rows % (store_size / type_size * THREADS_PER_WARP) == 0); + // Pick kernel config + std::vector kernel_configs; + kernel_configs.reserve(16); + auto add_config = [&](size_t load_size, size_t store_size) { + kernel_configs.emplace_back(row_length, num_rows, type_size, + load_size, store_size); }; - auto num_blocks = [&](size_t load_size_, size_t store_size_) -> int { - const size_t row_tile_size = load_size_ / type_size * THREADS_PER_WARP; - const size_t col_tile_size = store_size_ / type_size * THREADS_PER_WARP; - return (row_length / row_tile_size) * (num_rows / col_tile_size); - }; - do { - const int sm_count = cuda::sm_count(); - - // Try maximizing SM occupancy without sacrificing cache - // efficiency - // Note: 32 threads/warp access 128B L1 cache line, so 4B - // loads/stores achieve full cache efficiency - if constexpr (type_size > 4) break; - if (is_tile_aligned(load_size, store_size) - && num_blocks(load_size, store_size) >= 4*sm_count) { - break; - } - load_size = 4; store_size = 8; - if (is_tile_aligned(load_size, store_size) - && num_blocks(load_size, store_size) >= 4*sm_count) { - break; - } - load_size = 4; store_size = 4; - if (is_tile_aligned(load_size, store_size) - && num_blocks(load_size, store_size) >= sm_count) { - break; - } - - // Simple performance model to balance SM occupancy and cache - // efficiency - auto cost = [&](int load_size_, int store_size_) -> double { - int active_sms = std::min(sm_count, num_blocks(load_size_, store_size_)); - // Amortize memory accesses over 128B L1 cache line - int elements_per_load = std::min(128, load_size_) / 
type_size; - int elements_per_store = std::min(128, store_size_) / type_size; - return (1.0 / elements_per_load + 1.0 / elements_per_store) / active_sms; - }; - if constexpr (type_size > 2) break; - if (is_tile_aligned(load_size, store_size) - && cost(2, 4) >= cost(load_size, store_size)) { - break; - } - load_size = 2; store_size = 4; - if (is_tile_aligned(load_size, store_size) - && cost(2, 2) >= cost(load_size, store_size)) { - break; - } - load_size = 2; store_size = 2; - if constexpr (type_size > 1) break; - if (is_tile_aligned(load_size, store_size) - && cost(1, 2) >= cost(load_size, store_size)) { - break; - } - load_size = 1; store_size = 2; - if (is_tile_aligned(load_size, store_size) - && cost(1, 1) >= cost(load_size, store_size)) { - break; - } - load_size = 1; store_size = 1; - } while (false); - NVTE_CHECK(is_tile_aligned(load_size, store_size), - "memory accesses are not properly aligned"); + add_config(8, 8); + add_config(4, 8); add_config(8, 4); + add_config(4, 4); + add_config(2, 8); add_config(8, 2); + add_config(2, 4); add_config(4, 2); + add_config(2, 2); + add_config(1, 8); add_config(8, 1); + add_config(1, 4); add_config(4, 1); + add_config(1, 2); add_config(2, 1); + add_config(1, 1); + const auto &kernel_config = *std::min_element(kernel_configs.begin(), + kernel_configs.end()); + NVTE_CHECK(kernel_config.valid, "invalid kernel config"); + const size_t load_size = kernel_config.load_size; + const size_t store_size = kernel_config.store_size; + const size_t num_blocks = kernel_config.num_blocks; // Compile NVRTC kernel if needed and launch auto& rtc_manager = rtc::KernelManager::instance(); const std::string kernel_label = concat_strings("transpose" ",type=", type_name, ",load_size=", load_size, - ",store_size", store_size); + ",store_size=", store_size); if (!rtc_manager.is_compiled(kernel_label)) { std::string code = string_code_transpose_rtc_transpose_cu; code = regex_replace(code, "__TYPE__", type_name); @@ -259,7 +290,7 @@ void transpose(const Tensor &input, "transformer_engine/common/transpose/rtc/transpose.cu"); } rtc_manager.launch(kernel_label, - num_blocks(load_size, store_size), block_size, 0, stream, + num_blocks, block_size, 0, stream, static_cast(input.data.dptr), static_cast(noop.data.dptr), static_cast(output.data.dptr),