Skip to content

Commit

Permalink
Add path to disable cudnn norm for mxfp8 (#1432)
Browse files — browse the repository at this point in the history
* Add path to disable cudnn norm for mxfp8

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
  • Loading branch information
ksivaman and pre-commit-ci[bot] authored Jan 28, 2025
1 parent b653134 commit fb1a241
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 19 deletions.
32 changes: 16 additions & 16 deletions tests/cpp/run_norm_tests.sh
Original file line number Diff line number Diff line change
Expand Up @@ -10,26 +10,26 @@ fi
mkdir -p outputs
OUT="outputs/$OUTPUT_FILE"

echo "NVTE_FWD_LAYERNORM_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*LN*.*X0" >> $OUT
NVTE_FWD_LAYERNORM_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*LN*.*X0 >> $OUT
echo "NVTE_NORM_FWD_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*LN*.*X0" >> $OUT
NVTE_NORM_FWD_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*LN*.*X0 >> $OUT

echo "NVTE_FWD_LAYERNORM_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*LN*.*X1" >> $OUT
NVTE_FWD_LAYERNORM_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*LN*.*X1 >> $OUT
echo "NVTE_NORM_FWD_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*LN*.*X1" >> $OUT
NVTE_NORM_FWD_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*LN*.*X1 >> $OUT

echo "NVTE_BWD_LAYERNORM_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*LN*.*X0" >> $OUT
NVTE_BWD_LAYERNORM_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*LN*.*X0 >> $OUT
echo "NVTE_NORM_BWD_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*LN*.*X0" >> $OUT
NVTE_NORM_BWD_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*LN*.*X0 >> $OUT

echo "NVTE_BWD_LAYERNORM_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*LN*.*X1" >> $OUT
NVTE_BWD_LAYERNORM_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*LN*.*X1 >> $OUT
echo "NVTE_NORM_BWD_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*LN*.*X1" >> $OUT
NVTE_NORM_BWD_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*LN*.*X1 >> $OUT

echo "NVTE_FWD_RMSNORM_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*RMS*.*X0" >> $OUT
NVTE_FWD_RMSNORM_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*RMS*.*X0 >> $OUT
echo "NVTE_NORM_FWD_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*RMS*.*X0" >> $OUT
NVTE_NORM_FWD_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*RMS*.*X0 >> $OUT

echo "NVTE_FWD_RMSNORM_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*RMS*.*X1" >> $OUT
NVTE_FWD_RMSNORM_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*RMS*.*X1 >> $OUT
echo "NVTE_NORM_FWD_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*RMS*.*X1" >> $OUT
NVTE_NORM_FWD_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*RMS*.*X1 >> $OUT

echo "NVTE_BWD_RMSNORM_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*RMS*.*X0" >> $OUT
NVTE_BWD_RMSNORM_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*RMS*.*X0 >> $OUT
echo "NVTE_NORM_BWD_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*RMS*.*X0" >> $OUT
NVTE_NORM_BWD_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*RMS*.*X0 >> $OUT

echo "NVTE_BWD_RMSNORM_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*RMS*.*X1" >> $OUT
NVTE_BWD_RMSNORM_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*RMS*.*X1 >> $OUT
echo "NVTE_NORM_BWD_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*RMS*.*X1" >> $OUT
NVTE_NORM_BWD_USE_CUDNN=1 ./build/operator/test_operator --gtest_filter=*RMS*.*X1 >> $OUT
19 changes: 16 additions & 3 deletions transformer_engine/pytorch/module/_common.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

"""Internal function used by multiple modules."""

import os
from typing import Any, List, Optional, Tuple, Union, Callable
from dataclasses import dataclass

Expand All @@ -12,6 +13,10 @@
from .. import cpp_extensions as tex
from ..constants import TE_DType
from ..utils import get_default_init_method
from ..tensor.mxfp8_tensor import MXFP8Quantizer


_use_cudnn_mxfp8_norm = bool(int(os.getenv("NVTE_CUDNN_MXFP8_NORM", "1")))


def _get_normalization_func(normalization: str, forward: bool):
Expand Down Expand Up @@ -46,17 +51,25 @@ def apply_normalization(

inputs = (inputmat, ln_weight) if ln_bias is None else (inputmat, ln_weight, ln_bias)

split_mxfp8_cast = False
if not _use_cudnn_mxfp8_norm and isinstance(output_quantizer, MXFP8Quantizer):
split_mxfp8_cast = True

output = normalization_func(
*inputs,
eps,
ln_out,
output_quantizer,
None if split_mxfp8_cast else ln_out,
None if split_mxfp8_cast else output_quantizer,
TE_DType[output_dtype] if output_dtype in TE_DType else output_dtype,
fwd_ln_sm_margin,
zero_centered_gamma,
)

return output
return (
(output_quantizer.quantize(output[0], out=ln_out), *output[1:])
if split_mxfp8_cast
else output
)


class _NoopCatFunc(torch.autograd.Function):
Expand Down

0 comments on commit fb1a241

Please sign in to comment.