Rename block scaling recipe (#1442)
Rename MXFP8 recipe

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
ksivaman authored Jan 31, 2025
1 parent b5e6657 commit 5955f7e
Showing 10 changed files with 37 additions and 37 deletions.
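For downstream users, the rename is a one-line change at the recipe construction site. The sketch below is illustrative only (the te.Linear layer, tensor shapes, and variable names are assumptions, not part of this commit), and actually running it requires MXFP8-capable hardware (Blackwell-class, compute capability 10.0+):

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import MXFP8BlockScaling  # before this commit: BlockScaling

# Hypothetical layer and input, purely for illustration.
layer = te.Linear(1024, 1024).cuda()
x = torch.randn(32, 1024, device="cuda")

# The recipe class is the only thing this commit renames; usage is unchanged.
with te.fp8_autocast(enabled=True, fp8_recipe=MXFP8BlockScaling()):
    y = layer(x)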
4 changes: 2 additions & 2 deletions tests/pytorch/distributed/run_numerics.py
@@ -16,7 +16,7 @@
import torch.distributed as dist

from transformer_engine.common.recipe import (
- BlockScaling,
+ MXFP8BlockScaling,
DelayedScaling,
Format,
Recipe,
@@ -44,7 +44,7 @@ def quantization_recipe() -> Recipe:
fp8_format=Format.HYBRID, amax_history_len=32, amax_compute_algo="max"
)
if QUANTIZATION == "mxfp8":
- return BlockScaling()
+ return MXFP8BlockScaling()
return te.fp8.get_default_fp8_recipe()


2 changes: 1 addition & 1 deletion tests/pytorch/distributed/test_fusible_ops.py
@@ -136,7 +136,7 @@ def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "mxfp8":
- return transformer_engine.common.recipe.BlockScaling(
+ return transformer_engine.common.recipe.MXFP8BlockScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
raise ValueError(f"Unsupported quantization scheme ({name})")
4 changes: 2 additions & 2 deletions tests/pytorch/test_cuda_graphs.py
@@ -53,7 +53,7 @@ class ModelConfig:

fp8_recipes = [
recipe.DelayedScaling(),
- recipe.BlockScaling(),
+ recipe.MXFP8BlockScaling(),
]

# Supported data types
@@ -315,7 +315,7 @@ def test_make_graphed_callables(
pytest.skip("FP8 needed for FP8 parameters.")
if fp8_weight_caching and not fp8:
pytest.skip("FP8 needed for FP8 parameters.")
- if fp8_recipe.block() and not mxfp8_available:
+ if fp8_recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)

# Run model with different CUDA graph settings.
2 changes: 1 addition & 1 deletion tests/pytorch/test_fusible_ops.py
@@ -148,7 +148,7 @@ def make_recipe(name: Optional[str] = None) -> Optional[Recipe]:
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
if name == "mxfp8":
- return transformer_engine.common.recipe.BlockScaling(
+ return transformer_engine.common.recipe.MXFP8BlockScaling(
fp8_format=transformer_engine.common.recipe.Format.E4M3,
)
raise ValueError(f"Unsupported quantization scheme ({name})")
18 changes: 9 additions & 9 deletions tests/pytorch/test_numerics.py
@@ -98,7 +98,7 @@ def __init__(self, hidden_size, eps, num_attention_heads, embed, num_layers, seq
mask_types = ["causal", "no_mask"]

fp8_recipes = [
- recipe.BlockScaling(),
+ recipe.MXFP8BlockScaling(),
recipe.DelayedScaling(),
]

@@ -556,7 +556,7 @@ def _test_e2e_selective_recompute(
def test_gpt_selective_activation_recompute(dtype, bs, model, fp8, recipe, fp8_model_params):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
- if recipe.block() and not mxfp8_available:
+ if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)

config = model_configs[model]
@@ -668,7 +668,7 @@ def test_gpt_full_activation_recompute(
):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
- if recipe.block() and not mxfp8_available:
+ if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)

config = model_configs[model]
@@ -1418,7 +1418,7 @@ def _test_grouped_linear_accuracy(block, num_gemms, bs, dtype, config, recipe, f
if fp8:
if recipe.delayed():
split_size = 16
- if recipe.block():
+ if recipe.mxfp8():
split_size = 128
m = config.seq_len // split_size
dist = torch.sort(torch.randint(0, m, (num_gemms - 2,))).values.tolist()
@@ -1463,9 +1463,9 @@ def test_grouped_linear_accuracy(
):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
- if recipe.block() and not mxfp8_available:
+ if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
- if fp8 and recipe.block(): # TODO(ksivamani): debug mismatches
+ if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches
pytest.skip("MXFP8 unsupported for grouped linear.")

config = model_configs[model]
@@ -1648,9 +1648,9 @@ def test_padding_grouped_linear_accuracy(
):
if fp8 and not fp8_available:
pytest.skip(reason_for_no_fp8)
- if recipe.block() and not mxfp8_available:
+ if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)
- if fp8 and recipe.block(): # TODO(ksivamani): debug mismatches
+ if fp8 and recipe.mxfp8(): # TODO(ksivamani): debug mismatches
pytest.skip("MXFP8 unsupported for grouped linear.")

config = model_configs[model]
@@ -1860,7 +1860,7 @@ def _test_gpt_fp8_parameters(bs, dtype, config, fp8_model_params, recipe):
def test_gpt_fp8_parameters(dtype, bs, model, recipe):
if not fp8_available:
pytest.skip(reason_for_no_fp8)
- if recipe.block() and not mxfp8_available:
+ if recipe.mxfp8() and not mxfp8_available:
pytest.skip(reason_for_no_mxfp8)

config = model_configs[model]
8 changes: 4 additions & 4 deletions transformer_engine/common/recipe/__init__.py
@@ -44,9 +44,9 @@ class Recipe:
Base recipe class.
"""

- def block(self):
- """Whether the given recipe is block scaling."""
- return isinstance(self, BlockScaling)
+ def mxfp8(self):
+ """Whether the given recipe is MXFP8 block scaling."""
+ return isinstance(self, MXFP8BlockScaling)

def delayed(self):
"""Whether the given recipe is delayed scaling."""
@@ -162,7 +162,7 @@ def __repr__(self) -> str:


@dataclass()
- class BlockScaling(Recipe):
+ class MXFP8BlockScaling(Recipe):
"""
Use the current scaling factor strategy.
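The rename keeps the isinstance-based helpers on the Recipe base class; a minimal sketch of the expected behavior after this change (illustrative, not part of the diff):

from transformer_engine.common.recipe import DelayedScaling, MXFP8BlockScaling

# mxfp8() replaces the old block() helper; delayed() is unchanged.
assert MXFP8BlockScaling().mxfp8() and not MXFP8BlockScaling().delayed()
assert DelayedScaling().delayed() and not DelayedScaling().mxfp8()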
24 changes: 12 additions & 12 deletions transformer_engine/pytorch/fp8.py
@@ -13,7 +13,7 @@

import torch
import transformer_engine_torch as tex
- from transformer_engine.common.recipe import Recipe, DelayedScaling, Format, BlockScaling
+ from transformer_engine.common.recipe import Recipe, DelayedScaling, Format, MXFP8BlockScaling

from .constants import dist_group_type
from .utils import get_device_compute_capability
@@ -46,7 +46,7 @@ def check_mxfp8_support() -> Tuple[bool, str]:
def get_default_fp8_recipe() -> Recipe:
"""FP8 recipe with default args."""
if get_device_compute_capability() >= (10, 0): # blackwell and above
- return BlockScaling()
+ return MXFP8BlockScaling()
return DelayedScaling()


@@ -211,7 +211,7 @@ def add_fp8_tensors_to_global_buffer(
wrapper. For non CG case, it's called from within the module.
"""

if fp8_meta["recipe"].block():
if fp8_meta["recipe"].mxfp8():
return

# Every module must call this function exactly once since
@@ -414,7 +414,7 @@ def fp8_autocast_enter(
if enabled:
fp8_available, reason_for_no_fp8 = cls.is_fp8_available()
assert fp8_available, reason_for_no_fp8
- if isinstance(fp8_recipe, BlockScaling):
+ if isinstance(fp8_recipe, MXFP8BlockScaling):
mxfp8_available, reason_for_no_mxfp8 = cls.is_mxfp8_available()
assert mxfp8_available, reason_for_no_mxfp8

@@ -434,7 +434,7 @@ def copy_forward_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -
to ensure both forward steps are numerically same.
"""

if fp8_meta["recipe"].block():
if fp8_meta["recipe"].mxfp8():
return

buffer_position_key = "global_fp8_buffer_pos_fwd_recompute"
@@ -460,7 +460,7 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non
1 forward for indentical numerical outputs.
"""

if fp8_meta["recipe"].block():
if fp8_meta["recipe"].mxfp8():
return

# Store updated amaxes and scales from phase 1 post forward.
@@ -479,7 +479,7 @@ def get_old_fp8_meta_tensors_for_recompute(cls, fp8_meta: Dict[str, Any]) -> Non
def restore_fp8_meta_tensors(fp8_meta: Dict[str, Any]) -> None:
"""Restore latest scaling factors and amaxes after recompute forward run."""

if fp8_meta["recipe"].block():
if fp8_meta["recipe"].mxfp8():
return

fp8_meta["scaling_fwd"].amax_history.copy_(fp8_meta["updated_amax_history_fwd"])
@@ -741,8 +741,8 @@ def create(
cls = None
if recipe.delayed():
cls = DelayedScalingRecipeState
- elif recipe.block():
- cls = BlockScalingRecipeState
+ elif recipe.mxfp8():
+ cls = MXFP8BlockScalingRecipeState
else:
raise ValueError("{recipe.__class__.__name__} is not supported")
return cls(
@@ -813,20 +813,20 @@ def make_quantizers(self) -> list:
]


- class BlockScalingRecipeState(RecipeState):
+ class MXFP8BlockScalingRecipeState(RecipeState):
"""Configuration for MXFP8 quantization.
MXFP8 quantization does not require state.
"""

- recipe: BlockScaling
+ recipe: MXFP8BlockScaling
mode: str
dtype: tex.DType

def __init__(
self,
- recipe: BlockScaling,
+ recipe: MXFP8BlockScaling,
*,
mode: str,
num_quantizers: int = 1,
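One of the hunks above also touches get_default_fp8_recipe(): the default recipe on Blackwell-class GPUs (compute capability 10.0 and up) is now constructed as MXFP8BlockScaling rather than BlockScaling, with behavior otherwise unchanged. A minimal way to check the selected default (assumes a CUDA device is visible; illustrative only):

from transformer_engine.pytorch.fp8 import get_default_fp8_recipe

recipe = get_default_fp8_recipe()
# Prints "MXFP8BlockScaling" on Blackwell (SM 10.0+), "DelayedScaling" otherwise.
print(type(recipe).__name__)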
4 changes: 2 additions & 2 deletions transformer_engine/pytorch/module/base.py
@@ -22,7 +22,7 @@

from ._common import _ParameterInitMeta
from ..fp8 import (
- BlockScalingRecipeState,
+ MXFP8BlockScalingRecipeState,
DelayedScalingRecipeState,
FP8GlobalStateManager,
RecipeState,
@@ -540,7 +540,7 @@ def set_meta_tensor(self, fwd: bool, recipe: Recipe) -> None:
if recipe.delayed() and isinstance(recipe_state, DelayedScalingRecipeState):
self.adjust_amax_history_length(recipe.amax_history_len, fwd=fwd)
return
- if recipe.block() and isinstance(recipe_state, BlockScalingRecipeState):
+ if recipe.mxfp8() and isinstance(recipe_state, MXFP8BlockScalingRecipeState):
return

# Max. number of fp8 tensors per GEMM = 3 (input, weight, output) for fwd and
2 changes: 1 addition & 1 deletion transformer_engine/pytorch/module/grouped_linear.py
@@ -86,7 +86,7 @@ def forward(
device = inp.device

# TODO Support MXFP8 # pylint: disable=fixme
- if fp8 and FP8GlobalStateManager.get_fp8_recipe().block():
+ if fp8 and FP8GlobalStateManager.get_fp8_recipe().mxfp8():
raise NotImplementedError("GroupedLinear does not yet support MXFP8")

# Make sure input dimensions are compatible
6 changes: 3 additions & 3 deletions transformer_engine/pytorch/ops/op.py
@@ -15,7 +15,7 @@

from transformer_engine.common.recipe import Recipe
from ..fp8 import (
- BlockScalingRecipeState,
+ MXFP8BlockScalingRecipeState,
DelayedScalingRecipeState,
FP8GlobalStateManager,
RecipeState,
@@ -260,7 +260,7 @@ def _update_quantization_recipe_state(
recipe_state = self._fp8_metas[mode][fp8_meta_key]
need_to_reset_recipe_state = (
recipe.delayed() and not isinstance(recipe_state, DelayedScalingRecipeState)
- ) or (recipe.block() and not isinstance(recipe_state, BlockScalingRecipeState))
+ ) or (recipe.mxfp8() and not isinstance(recipe_state, MXFP8BlockScalingRecipeState))
if need_to_reset_recipe_state:
self._reset_quantization_recipe_state(recipe=recipe)
return
@@ -283,7 +283,7 @@ def _update_quantization_recipe_state(
recipe_state = fp8_meta[fp8_meta_key]

# Reallocate amax history if needed
- if recipe.block():
+ if recipe.mxfp8():
continue

current_length = recipe_state.amax_history.size(0)
