[common] Generalized MXFP8 gated kernels w.r.t. input tensor dimensions (#1449)

* Fixed scaling tensor alignment/padding

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Changes from review

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed alignment and padding in scaled tensors. Refactoring.

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Skipped scenarios for non-mod(32) tensors

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixes

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* More fixes

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* Some fixes to the CPU reference

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fixed typo in the kernel. Restricted the last dim to multiples of 32

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* Fixed TMA writes overlap

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Remove the largest test cases for numerical stability

Signed-off-by: Przemek Tredak <ptredak@nvidia.com>

---------

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Przemek Tredak <ptredak@nvidia.com>
Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
4 people authored Feb 5, 2025
1 parent 6af5ca3 commit ce8b127
Showing 11 changed files with 386 additions and 226 deletions.
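Much of the test churn below replaces hand-rolled block arithmetic with a shared get_scale_tensor_dims helper whose definition is not part of this diff. A minimal sketch consistent with the removed inline code follows; the concrete alignment values are an assumption here (a 128x4 padding convention), and the authoritative constants live in the shared test headers.

#include <array>
#include <cstddef>

// Assumed alignment of MXFP8 scaling tensors; the real constants are defined
// in the common test utilities, not in this commit.
constexpr size_t scale_tensor_alignment_X_rowwise = 4;
constexpr size_t scale_tensor_alignment_Y_rowwise = 128;
constexpr size_t scale_tensor_alignment_X_colwise = 128;
constexpr size_t scale_tensor_alignment_Y_colwise = 4;

constexpr size_t divide_round_up(const size_t x, const size_t y) {
  return (x + y - 1) / y;
}

constexpr size_t round_up_to_nearest_multiple(const size_t x, const size_t m) {
  return divide_round_up(x, m) * m;
}

// Returns {unpadded_blocks_Y, unpadded_blocks_X, padded_blocks_Y, padded_blocks_X}.
inline std::array<size_t, 4> get_scale_tensor_dims(const size_t rows, const size_t cols,
                                                   const size_t block_size_Y,
                                                   const size_t block_size_X) {
  const bool rowwise = (block_size_Y == 1) && (block_size_X == 32);

  // Number of scaling blocks actually covering the data.
  const size_t unpadded_blocks_Y = divide_round_up(rows, block_size_Y);
  const size_t unpadded_blocks_X = divide_round_up(cols, block_size_X);

  const size_t alignment_Y = rowwise ? scale_tensor_alignment_Y_rowwise
                                     : scale_tensor_alignment_Y_colwise;
  const size_t alignment_X = rowwise ? scale_tensor_alignment_X_rowwise
                                     : scale_tensor_alignment_X_colwise;

  // Pad each dimension up to its alignment requirement.
  const size_t blocks_Y = round_up_to_nearest_multiple(unpadded_blocks_Y, alignment_Y);
  const size_t blocks_X = round_up_to_nearest_multiple(unpadded_blocks_X, alignment_X);

  return {unpadded_blocks_Y, unpadded_blocks_X, blocks_Y, blocks_X};
}

The tests then unpack the returned array in that order, as performTest_x1 and performTest_x2 below illustrate.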
4 changes: 4 additions & 0 deletions tests/cpp/operator/test_cast_gated_swiglu.cu
@@ -138,6 +138,10 @@ TEST_P(CastSwiGLUTestSuite, TestCastSwiGLU) {
const DType output_type = std::get<1>(GetParam());
const auto size = std::get<2>(GetParam());

if (size.back() % 32 != 0) {
GTEST_SKIP();
}

TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
input_type, InputType,
TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
52 changes: 19 additions & 33 deletions tests/cpp/operator/test_cast_mxfp8.cu
@@ -4,13 +4,6 @@
* See LICENSE for license information.
************************************************************************/

#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>
#include <limits>

#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
@@ -187,19 +180,14 @@ void performTest_x1(const ProcessingMethod processing_method,

const size_t block_size_rows = rowwise ? 1 : 32;
const size_t block_size_cols = colwise ? 1 : 32;
const size_t unpadded_blocks_Y = (rows + block_size_rows - 1) / block_size_rows;
const size_t unpadded_blocks_X = (cols + block_size_cols - 1) / block_size_cols;

const size_t block_alignment_X = rowwise
? scale_tensor_alignment_X_rowwise
: scale_tensor_alignment_X_colwise;
const size_t block_alignment_Y = rowwise
? scale_tensor_alignment_Y_rowwise
: scale_tensor_alignment_Y_colwise;

// Roundup to the nearest multiple
const size_t blocks_Y = ((unpadded_blocks_Y + block_alignment_Y - 1) / block_alignment_Y) * block_alignment_Y;
const size_t blocks_X = ((unpadded_blocks_X + block_alignment_X - 1) / block_alignment_X) * block_alignment_X;

const std::array<size_t,4> scale_dims = get_scale_tensor_dims(rows, cols, block_size_rows,
block_size_cols);

const size_t unpadded_blocks_Y = scale_dims[0];
const size_t unpadded_blocks_X = scale_dims[1];
const size_t blocks_Y = scale_dims[2];
const size_t blocks_X = scale_dims[3];
const size_t scales_stride = blocks_X;

Tensor input(shape, itype);
@@ -325,21 +313,19 @@ void performTest_x2(const ProcessingMethod processing_method,
const size_t rows = first_dimension(shape);
const size_t cols = last_dimension(shape);

const size_t unpadded_blocks_Y_rowwise = rows;
const size_t unpadded_blocks_X_rowwise = divide_round_up(cols, block_size_cols);
const size_t unpadded_blocks_Y_colwise = divide_round_up(rows, block_size_rows);
const size_t unpadded_blocks_X_colwise = cols;

const size_t blocks_Y_rowwise = round_up_to_nearest_multiple(unpadded_blocks_Y_rowwise,
scale_tensor_alignment_Y_rowwise);
const size_t blocks_X_rowwise = round_up_to_nearest_multiple(unpadded_blocks_X_rowwise,
scale_tensor_alignment_X_rowwise);
const size_t blocks_Y_colwise = round_up_to_nearest_multiple(unpadded_blocks_Y_colwise,
scale_tensor_alignment_Y_colwise);
const size_t blocks_X_colwise = round_up_to_nearest_multiple(unpadded_blocks_X_colwise,
scale_tensor_alignment_X_colwise);
const std::array<size_t,4> scale_dims_rowwise = get_scale_tensor_dims(rows, cols, 1, 32);
const std::array<size_t,4> scale_dims_colwise = get_scale_tensor_dims(rows, cols, 32, 1);

const size_t unpadded_blocks_Y_rowwise = scale_dims_rowwise[0];
const size_t unpadded_blocks_X_rowwise = scale_dims_rowwise[1];
const size_t blocks_Y_rowwise = scale_dims_rowwise[2];
const size_t blocks_X_rowwise = scale_dims_rowwise[3];
const size_t scales_stride_rowwise = blocks_X_rowwise;

const size_t unpadded_blocks_Y_colwise = scale_dims_colwise[0];
const size_t unpadded_blocks_X_colwise = scale_dims_colwise[1];
const size_t blocks_Y_colwise = scale_dims_colwise[2];
const size_t blocks_X_colwise = scale_dims_colwise[3];
const size_t scales_stride_colwise = blocks_X_colwise;

Tensor input(shape, itype);
147 changes: 104 additions & 43 deletions tests/cpp/operator/test_cast_mxfp8_gated_swiglu.cu
@@ -4,12 +4,6 @@
* See LICENSE for license information.
************************************************************************/

#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>

#include <cuda_bf16.h>
#include <cuda_fp8.h>
#include <cuda_runtime.h>
@@ -30,6 +24,7 @@ void scale_block(const IType* grad,
OType* output,
fp8e8m0* output_scales,
const size_t scale_idx,
const size_t scale_idx_gate,
float& thread_amax,
const size_t i_min,
const size_t i_max,
const size_t j_min,
const size_t j_max,
const size_t cols) {

float block_amax = 0.0f;
float block_amax_gate = 0.0f;
const size_t stride = cols * 2;

// Find the absolute maximum value in the block
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
float silu_elt = static_cast<float>(input[i * stride + j]);
float gate_elt = static_cast<float>(input[i * stride + cols + j]);
float gated_amax;
float gated_amax_act = 0;
float gated_amax_gate = 0;

if constexpr (IS_DGATED) {
const float grad_elt = static_cast<float>(grad[i * cols + j]);
const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt;
const float after_dgate = silu(silu_elt) * grad_elt;
gated_amax = max(abs(after_dsilu), abs(after_dgate));
gated_amax_act = abs(after_dsilu);
gated_amax_gate = abs(after_dgate);
} else {
const float after_silu = silu(silu_elt) * gate_elt;
gated_amax = abs(after_silu);
gated_amax_act = abs(after_silu);
}

if (abs(gated_amax) > block_amax) { block_amax = abs(gated_amax); }
if (gated_amax_act > block_amax) { block_amax = gated_amax_act; }
if (gated_amax_gate > block_amax_gate) { block_amax_gate = gated_amax_gate; }
}
}

const fp8e8m0 biased_exponent = float_to_e8m0(block_amax * Quantized_Limits<OType>::max_reciprocal());
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax *
Quantized_Limits<OType>::max_reciprocal());
const float scale_reciprocal = exp2f_rcp(biased_exponent);
output_scales[scale_idx] = biased_exponent;
float scale_reciprocal_gate = 1;
if constexpr (IS_DGATED) {
const fp8e8m0 biased_exponent = float_to_e8m0(block_amax_gate *
Quantized_Limits<OType>::max_reciprocal());
scale_reciprocal_gate = exp2f_rcp(biased_exponent);
output_scales[scale_idx_gate] = biased_exponent;
}


// Quantize elements in the block
for (size_t i = i_min; i < i_max; ++i) {
for (size_t j = j_min; j < j_max; ++j) {
float silu_elt = static_cast<float>(input[i * stride + j]);
float gate_elt = static_cast<float>(input[i * stride + cols + j]);

if constexpr (IS_DGATED) {
const float grad_elt = static_cast<float>(grad[i * cols + j]);
const float after_dsilu = dsilu(silu_elt) * grad_elt * gate_elt;
const float after_dgate = silu(silu_elt) * grad_elt;
output[i * stride + j] = static_cast<OType>(after_dsilu * scale_reciprocal);
output[i * stride + cols + j] = static_cast<OType>(after_dgate * scale_reciprocal);
output[i * stride + cols + j] = static_cast<OType>(after_dgate *
scale_reciprocal_gate);
} else {
const float after_silu = silu(silu_elt) * gate_elt;
output[i * cols + j] = static_cast<OType>(after_silu * scale_reciprocal);
}
}
}
thread_amax = std::max(thread_amax, block_amax);
thread_amax = std::max(thread_amax, block_amax_gate);
}
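The reference above leans on two helpers from the test utilities, float_to_e8m0 and exp2f_rcp, which are not shown in this diff. A simplified sketch of their presumed behavior (the real versions also handle NaN/Inf and saturation):

#include <algorithm>
#include <cmath>
#include <cstdint>

using fp8e8m0 = uint8_t;  // E8M0: 8 exponent bits, no mantissa, bias 127

// Encode a positive float as a biased power-of-two exponent, rounding the
// exponent up so that dividing by the decoded scale cannot overflow.
inline fp8e8m0 float_to_e8m0(const float val) {
  if (val <= 0.0f) return 0;  // degenerate block (amax == 0)
  const int biased = static_cast<int>(std::ceil(std::log2(val))) + 127;
  return static_cast<fp8e8m0>(std::min(std::max(biased, 0), 254));
}

// Reciprocal of the encoded scale: 2^(127 - biased_exponent).
inline float exp2f_rcp(const fp8e8m0 biased_exponent) {
  return std::exp2(127.0f - static_cast<float>(biased_exponent));
}

Under this scheme the activation and gate blocks each get their own round-up exponent, and both block_amax and block_amax_gate feed thread_amax, as in the code above.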

template <bool IS_DGATED, typename IType, typename OType>
Expand All @@ -96,14 +106,14 @@ void compute_ref_x1(const IType* grad,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X) {
const size_t block_size_X,
const size_t scales_stride) {
const size_t tile_size_Y = std::max(32lu, block_size_Y);
const size_t tile_size_X = std::max(64lu, block_size_X);
const size_t tiles_num_Y = (rows + tile_size_Y - 1) / tile_size_Y;
const size_t tiles_num_X = (cols + tile_size_X - 1) / tile_size_X;
const size_t blocks_per_tile_Y = tile_size_Y / block_size_Y;
const size_t blocks_per_tile_X = tile_size_X / block_size_X;
const size_t blocks_per_row = (cols + block_size_X - 1) / block_size_X;

float amax = 0;
#pragma omp parallel reduction(max: amax) proc_bind(spread)
Expand All @@ -120,17 +130,21 @@ void compute_ref_x1(const IType* grad,
const size_t block_idx_Y = tile_Y * blocks_per_tile_Y + ii;
const size_t block_offset_Y = ii * block_size_Y;
const size_t i_min = tile_offset_Y + block_offset_Y;
if (i_min >= rows) continue;
const size_t i_max = std::min(i_min + block_size_Y, rows);

for (size_t jj = 0; jj < blocks_per_tile_X; ++jj) {
const size_t block_idx_X = tile_X * blocks_per_tile_X + jj;
const size_t block_offset_X = jj * block_size_X;
const size_t j_min = tile_offset_X + block_offset_X;
if (j_min >= cols) continue;
const size_t j_max = std::min(j_min + block_size_X, cols);

const size_t mx_scale_idx = block_idx_Y * blocks_per_row + block_idx_X;
const size_t mx_scale_idx = block_idx_Y * scales_stride + block_idx_X;
const size_t mx_scale_idx_gate = block_idx_Y * scales_stride + block_idx_X +
cols / block_size_X;
scale_block<IS_DGATED, IType, OType>(
grad, input, output, output_scales, mx_scale_idx,
grad, input, output, output_scales, mx_scale_idx, mx_scale_idx_gate,
thread_amax, i_min, i_max, j_min, j_max, cols);
}
}
Expand All @@ -153,11 +167,13 @@ void compute_ref_x2(const IType* grad,
const size_t rows,
const size_t cols,
const size_t block_size_Y,
const size_t block_size_X) {
const size_t block_size_X,
const size_t scales_stride_rowwise,
const size_t scales_stride_colwise) {
compute_ref_x1<IS_DGATED, IType, OType>(
grad, input, output_rowwise, scales_rowwise, ref_amax, rows, cols, 1, block_size_X);
grad, input, output_rowwise, scales_rowwise, ref_amax, rows, cols, 1, block_size_X, scales_stride_rowwise);
compute_ref_x1<IS_DGATED, IType, OType>(
grad, input, output_colwise, scales_colwise, ref_amax, rows, cols, block_size_Y, 1);
grad, input, output_colwise, scales_colwise, ref_amax, rows, cols, block_size_Y, 1, scales_stride_colwise);
}
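In the dgated case the activation and gate halves of each row now receive independent scales: rowwise, the gate's scale lands in the same scale row, offset by the number of activation blocks (cols / block_size_X). A small worked example with hypothetical sizes:

#include <cstddef>

// Hypothetical sizes: cols = 128 per half, block_size_X = 32, scales_stride = 16,
// so each row owns cols / block_size_X = 4 activation scale slots before the
// gate slots begin.
constexpr size_t scale_idx(size_t block_idx_Y, size_t block_idx_X, size_t scales_stride) {
  return block_idx_Y * scales_stride + block_idx_X;
}

constexpr size_t scale_idx_gate(size_t block_idx_Y, size_t block_idx_X,
                                size_t scales_stride, size_t cols, size_t block_size_X) {
  return block_idx_Y * scales_stride + block_idx_X + cols / block_size_X;
}

static_assert(scale_idx(3, 2, 16) == 50, "activation scale of block (3, 2)");
static_assert(scale_idx_gate(3, 2, 16, 128, 32) == 54, "gate scale, 4 slots later");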

/**
@@ -167,7 +183,6 @@ void compute_ref_x2(const IType* grad,
* OR
* 2) Scaled columns + column-wise scaling factors
*/

template <bool IS_DGATED, typename IType, typename OType>
void performTest_x1(const size_t rows,
const size_t cols,
const size_t block_size_rows,
const size_t block_size_cols,
InputsFillCase fill_case) {
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;

bool rowwise = false, colwise = false;
if (block_size_rows == 1 && block_size_cols == 32) rowwise = true;
if (block_size_rows == 32 && block_size_cols == 1) colwise = true;
const bool rowwise = (block_size_rows == 1) && (block_size_cols == 32);
const bool colwise = (block_size_rows == 32) && (block_size_cols == 1);
NVTE_CHECK(rowwise || colwise);

const size_t blocks_Y = (rows + block_size_rows - 1) / block_size_rows;
const size_t blocks_X = (cols + block_size_cols - 1) / block_size_cols;
const size_t blocks_num = blocks_Y * blocks_X;
// std::cout << "unpadded_blocks_Y: " << unpadded_blocks_Y << std::endl;
// std::cout << "unpadded_blocks_X: " << unpadded_blocks_X << std::endl;
// std::cout << "blocks_Y: " << blocks_Y << std::endl;
// std::cout << "blocks_X: " << blocks_X << std::endl;
// std::cout << "scales_stride: " << scales_stride << std::endl;

Tensor grad({ rows, cols }, itype);
Tensor input({ rows, cols * 2 }, itype);

const size_t output_cols = (IS_DGATED ? 2 : 1) * cols;

const std::array<size_t,4> scale_dims = get_scale_tensor_dims(rows, output_cols, block_size_rows,
block_size_cols);

const size_t unpadded_blocks_Y = scale_dims[0];
const size_t unpadded_blocks_X = scale_dims[1];
const size_t blocks_Y = scale_dims[2];
const size_t blocks_X = scale_dims[3];
const size_t scales_stride = blocks_X;

Tensor output(std::vector<size_t>{ rows, output_cols }, otype, rowwise, colwise, NVTE_MXFP8_1D_SCALING);

std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(rows * output_cols);
std::unique_ptr<fp8e8m0[]> ref_output_scales = std::make_unique<fp8e8m0[]>(blocks_Y * blocks_X);

for (size_t i = 0; i < blocks_Y * blocks_X; ++i) {
ref_output_scales[i] = 0;
}

// fillCase<EncodingType>(&grad, fill_case);
if constexpr (IS_DGATED) {
fillUniform(&grad);
Expand All @@ -222,14 +252,21 @@ void performTest_x1(const size_t rows,
rows,
cols,
block_size_rows,
block_size_cols);
block_size_cols,
scales_stride);

auto [atol, rtol] = getTolerances(otype);
compareResults("output", output, ref_output.get(), rowwise, atol, rtol);

const uint8_t * const gpu_scales_ptr = rowwise
? output.rowwise_cpu_scale_inv_ptr<fp8e8m0>()
: output.columnwise_cpu_scale_inv_ptr<fp8e8m0>();
if (rowwise) {
compare_e8m0_scaling_factors("scales", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(), ref_output_scales.get(), blocks_num);
compare_e8m0_scaling_factors("rowwise scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride);
} else {
compare_e8m0_scaling_factors("scales", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(), ref_output_scales.get(), blocks_num);
compare_e8m0_scaling_factors("colwise scales", gpu_scales_ptr, ref_output_scales.get(),
unpadded_blocks_Y, unpadded_blocks_X, scales_stride);
}
}
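compare_e8m0_scaling_factors takes a strided signature in this commit: with padded scale tensors, only the unpadded region carries meaningful values. A sketch of what the comparison presumably does under that signature:

#include <cstddef>
#include <cstdint>
#include <string>

#include <gtest/gtest.h>

using fp8e8m0 = uint8_t;

// Walk only the unpadded rows x cols window of the scale tensors, using the
// padded row stride to index into both buffers.
inline void compare_e8m0_scaling_factors(const std::string& name, const fp8e8m0* test,
                                         const fp8e8m0* ref, const size_t rows,
                                         const size_t cols, const size_t stride) {
  for (size_t i = 0; i < rows; ++i) {
    for (size_t j = 0; j < cols; ++j) {
      const size_t idx = i * stride + j;
      ASSERT_EQ(test[idx], ref[idx])
          << "Mismatch in " << name << " at (" << i << ", " << j << ")";
    }
  }
}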

template <bool IS_DGATED, typename IType, typename OType>
void performTest_x2(const size_t rows,
const size_t cols,
const size_t block_size_rows,
const size_t block_size_cols,
InputsFillCase fill_case) {
DType itype = TypeInfo<IType>::dtype;
DType otype = TypeInfo<OType>::dtype;

const size_t blocks_Y = (rows + block_size_rows - 1) / block_size_rows;
const size_t blocks_X = (cols + block_size_cols - 1) / block_size_cols;
const size_t blocks_num_rowwise = rows * blocks_X;
const size_t blocks_num_colwise = blocks_Y * cols;

Tensor grad({ rows, cols }, itype);
Tensor input({ rows, cols * 2 }, itype);

const size_t output_cols = (IS_DGATED ? 2 : 1) * cols;

const std::array<size_t,4> scale_dims_rowwise = get_scale_tensor_dims(rows, output_cols, 1, 32);
const std::array<size_t,4> scale_dims_colwise = get_scale_tensor_dims(rows, output_cols, 32, 1);

const size_t unpadded_blocks_Y_rowwise = scale_dims_rowwise[0];
const size_t unpadded_blocks_X_rowwise = scale_dims_rowwise[1];
const size_t blocks_Y_rowwise = scale_dims_rowwise[2];
const size_t blocks_X_rowwise = scale_dims_rowwise[3];
const size_t scales_stride_rowwise = blocks_X_rowwise;

const size_t unpadded_blocks_Y_colwise = scale_dims_colwise[0];
const size_t unpadded_blocks_X_colwise = scale_dims_colwise[1];
const size_t blocks_Y_colwise = scale_dims_colwise[2];
const size_t blocks_X_colwise = scale_dims_colwise[3];
const size_t scales_stride_colwise = blocks_X_colwise;

Tensor output(std::vector<size_t>{ rows, output_cols }, otype, true, true, NVTE_MXFP8_1D_SCALING);

std::unique_ptr<OType[]> ref_output_rowwise = std::make_unique<OType[]>(rows * output_cols);
std::unique_ptr<OType[]> ref_output_colwise = std::make_unique<OType[]>(rows * output_cols);
std::unique_ptr<fp8e8m0[]> ref_scales_rowwise = std::make_unique<fp8e8m0[]>(rows * blocks_X);
std::unique_ptr<fp8e8m0[]> ref_scales_colwise = std::make_unique<fp8e8m0[]>(blocks_Y * cols);
std::unique_ptr<fp8e8m0[]> ref_scales_rowwise = std::make_unique<fp8e8m0[]>(blocks_Y_rowwise * blocks_X_rowwise);
std::unique_ptr<fp8e8m0[]> ref_scales_colwise = std::make_unique<fp8e8m0[]>(blocks_Y_colwise * blocks_X_colwise);

for (size_t i = 0; i < blocks_Y_rowwise * blocks_X_rowwise; ++i) {
ref_scales_rowwise[i] = 0;
}
for (size_t i = 0; i < blocks_Y_colwise * blocks_X_colwise; ++i) {
ref_scales_colwise[i] = 0;
}

// fillCase<EncodingType>(&grad, fill_case);
if constexpr (IS_DGATED) {
@@ -294,26 +349,32 @@ void performTest_x2(const size_t rows,
rows,
cols,
block_size_rows,
block_size_cols);
block_size_cols,
scales_stride_rowwise,
scales_stride_colwise);

auto [atol, rtol] = getTolerances(otype);
auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
compareResults("output_c_rowwise", output, ref_output_rowwise.get(), true, atol, rtol);
compareResults("output_c_colwise", output, ref_output_colwise.get(), false, atol, rtol);
compare_e8m0_scaling_factors("scales_rowwise", output.rowwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_rowwise.get(), blocks_num_rowwise);
ref_scales_rowwise.get(), unpadded_blocks_Y_rowwise,
unpadded_blocks_X_rowwise, scales_stride_rowwise);
compare_e8m0_scaling_factors("scales_colwise", output.columnwise_cpu_scale_inv_ptr<fp8e8m0>(),
ref_scales_colwise.get(), blocks_num_colwise);
ref_scales_colwise.get(), unpadded_blocks_Y_colwise,
unpadded_blocks_X_colwise, scales_stride_colwise);
}

std::vector<std::pair<size_t, size_t>> matrix_sizes = {
{1, 32},
{16, 64},
{65, 96},
{128, 128},
{256, 256},
{993, 512},
{768, 1024},
{256, 65536},
// {2048, 12288},
// {65536, 128},
// {16384, 6144},
{65536, 128},
{16384, 1632},
};

std::vector<std::pair<size_t, size_t>> block_sizes = {
{1, 32},
{32, 1},
{32, 32},
};