Skip to content

Commit

Permalink
[pre-commit.ci] auto fixes from pre-commit.com hooks
Browse files (browse the repository at this point in the history)
for more information, see https://pre-commit.ci
  • Loading branch information
pre-commit-ci[bot] committed Feb 4, 2025
1 parent 01955fc commit a90dc70
Show file tree
Hide file tree
Showing 4 changed files with 18 additions and 17 deletions.
3 changes: 2 additions & 1 deletion transformer_engine/common/common.h
Original file line number | Diff line number | Diff line change
Expand Up @@ -429,7 +429,8 @@ constexpr size_t scale_tensor_alignment_Y_colwise = 4;
size_t typeToSize(const DType type);

void CheckNoopTensor(const Tensor &t, const std::string &name);
void CheckInputTensor(const Tensor &t, const std::string &name, const bool is_gated_mxfp8_tensor = false);
void CheckInputTensor(const Tensor &t, const std::string &name,
const bool is_gated_mxfp8_tensor = false);
void CheckOutputTensor(const Tensor &t, const std::string &name, bool allow_empty = false,
const bool is_gated_mxfp8_tensor = false);

Expand Down
23 changes: 11 additions & 12 deletions transformer_engine/common/transformer_engine.cpp
Original file line number | Diff line number | Diff line change
Expand Up @@ -71,13 +71,12 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name,
if (is_tensor_scaling(t.scaling_mode)) {
// per-tensor scaling
if (t.has_data()) {
NVTE_CHECK(t.scale_inv.numel() == 1,
"Tensor \"", name, "\" has invalid scale_inv shape (expected (1), got ",
t.scale_inv.shape, ")");
NVTE_CHECK(t.scale_inv.numel() == 1, "Tensor \"", name,
"\" has invalid scale_inv shape (expected (1), got ", t.scale_inv.shape, ")");
}
if (t.has_columnwise_data()) {
NVTE_CHECK(t.columnwise_scale_inv.numel() == 1,
"Tensor \"", name, "\" has invalid columnwise_scale_inv shape (expected (1), got ",
NVTE_CHECK(t.columnwise_scale_inv.numel() == 1, "Tensor \"", name,
"\" has invalid columnwise_scale_inv shape (expected (1), got ",
t.columnwise_scale_inv.shape, ")");
}
} else {
Expand All @@ -87,27 +86,27 @@ void CheckScaleTensorShape(const Tensor &t, const std::string &name,
size_t expected_x, expected_y, alignment;

const size_t flat_first_dim = t.flat_first_dim();
const size_t flat_last_dim = is_gated_mxfp8_op ? (t.flat_last_dim() / 2): t.flat_last_dim();
const size_t flat_last_dim = is_gated_mxfp8_op ? (t.flat_last_dim() / 2) : t.flat_last_dim();

if (t.has_data()) {
alignment = block_alignment[0];
expected_x = DIVUP(DIVUP(flat_first_dim, static_cast<size_t>(1)), alignment) * alignment;
alignment = block_alignment[1];
expected_y = DIVUP(DIVUP(flat_last_dim, static_cast<size_t>(32)), alignment) * alignment;
const auto &expected = std::vector<size_t>{expected_x, expected_y};
NVTE_CHECK(t.scale_inv.shape == expected,
"Tensor \"", name, "\" has invalid scale_inv shape (expected ",
expected, ", got ", t.scale_inv.shape, ")");
NVTE_CHECK(t.scale_inv.shape == expected, "Tensor \"", name,
"\" has invalid scale_inv shape (expected ", expected, ", got ",
t.scale_inv.shape, ")");
}
if (t.has_columnwise_data()) {
alignment = block_alignment[1];
expected_x = DIVUP(DIVUP(flat_first_dim, static_cast<size_t>(32)), alignment) * alignment;
alignment = block_alignment[0];
expected_y = DIVUP(DIVUP(flat_last_dim, static_cast<size_t>(1)), alignment) * alignment;
const auto &expected = std::vector<size_t>{expected_x, expected_y};
NVTE_CHECK(t.columnwise_scale_inv.shape == expected,
"Tensor \"", name, "\" has invalid columnwise_scale_inv shape (expected ",
expected, ", got ", t.columnwise_scale_inv.shape, ")");
NVTE_CHECK(t.columnwise_scale_inv.shape == expected, "Tensor \"", name,
"\" has invalid columnwise_scale_inv shape (expected ", expected, ", got ",
t.columnwise_scale_inv.shape, ")");
}
}
}
Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/common/util/cast_gated_kernels.cuh
Original file line number | Diff line number | Diff line change
Expand Up @@ -701,9 +701,9 @@ void cast_fp8_gated(const Tensor &grad, const Tensor &gated_input, Tensor *outpu
create_2D_tensor_map(tensor_map_grad, grad.data, rows, cols, SHMEM_DIM_Y, SHMEM_DIM_X,
sizeof(IType));
}

create_2D_tensor_map(tensor_map_gated_input, gated_input.data, rows, cols * 2,
SHMEM_DIM_Y, SHMEM_DIM_X, sizeof(IType));
SHMEM_DIM_Y, SHMEM_DIM_X, sizeof(IType));
create_2D_tensor_map(tensor_map_output, output->data, rows, output_cols, SHMEM_DIM_Y,
SHMEM_DIM_X, sizeof(OType));

Expand Down
5 changes: 3 additions & 2 deletions transformer_engine/common/util/cast_kernels.cuh
Original file line number | Diff line number | Diff line change
Expand Up @@ -938,8 +938,9 @@ void mxfp8_quantize(const Tensor &input, const Tensor *act_input,
const size_t blocks_X = DIVUP(chunks_X, MXFP8_CHUNKS_PER_BLOCK_X);

const size_t scale_stride_rowwise = use_rowwise_scaling ? output->scale_inv.shape[1] : 1;
const size_t scale_stride_colwise = use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1;

const size_t scale_stride_colwise =
use_colwise_scaling ? output->columnwise_scale_inv.shape[1] : 1;

e8m0_t *const scales_rowwise_ptr =
use_rowwise_scaling ? reinterpret_cast<e8m0_t *>(output->scale_inv.dptr) : nullptr;
e8m0_t *const scales_colwise_ptr =
Expand Down

0 comments on commit a90dc70

Please sign in to comment.