From 908461ac42ad98107fbf3630d6fc5e48a1ab7046 Mon Sep 17 00:00:00 2001 From: Jaemin Choi Date: Fri, 7 Feb 2025 23:49:36 +0000 Subject: [PATCH] Add NVTX ranges to categorize execution Signed-off-by: Jaemin Choi Signed-off-by: Tim Moon Co-authored-by: Jaemin Choi Co-authored-by: Tim Moon --- transformer_engine/pytorch/attention.py | 18 +++++- .../pytorch/module/layernorm_linear.py | 44 +++++++++++++- transformer_engine/pytorch/module/linear.py | 34 ++++++++++- transformer_engine/pytorch/utils.py | 60 +++++++++++++++++++ 4 files changed, 150 insertions(+), 6 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index bf6adc309c..8584431dc2 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -23,7 +23,11 @@ import transformer_engine_torch as tex import transformer_engine as te -from transformer_engine.pytorch.utils import get_cudnn_version +from transformer_engine.pytorch.utils import ( + get_cudnn_version, + nvtx_range_pop, + nvtx_range_push, +) from transformer_engine.pytorch.cpp_extensions.fused_attn import ( fused_attn_fwd, fused_attn_bwd, @@ -1834,6 +1838,7 @@ def forward( quantizers, ): # pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward") if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -2756,12 +2761,14 @@ def forward( ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.forward") return out_ret @staticmethod def backward(ctx, dout): # pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward") cp_size_a2a = ctx.cp_size_a2a rank_a2a = ctx.rank_a2a @@ -3602,6 +3609,7 @@ def backward(ctx, dout): dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, ctx.qkv_dtype) dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, ctx.qkv_dtype) dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, ctx.qkv_dtype) + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.backward") return ( None, @@ -3688,6 +3696,7 @@ def forward( cp_stream, ): # pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -3904,11 +3913,13 @@ def forward( ctx.attn_mask_type = attn_mask_type ctx.deterministic = deterministic ctx.use_fused_attention = use_fused_attention + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward") return out @staticmethod def backward(ctx, dout): # pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") cp_size = get_distributed_world_size(ctx.cp_group) rank = get_distributed_rank(ctx.cp_group) @@ -4092,6 +4103,7 @@ def backward(ctx, dout): dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :]) dk = dk.movedim(0, seq_dim).contiguous() dv = dv.movedim(0, seq_dim).contiguous() + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward") return ( None, @@ -4151,6 +4163,7 @@ def forward( quantizers, ): # pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") if softmax_scale is None: softmax_scale = q.shape[-1] ** (-0.5) @@ -4403,11 +4416,13 @@ def forward( ctx.fp8_meta = fp8_meta ctx.is_input_fp8 = is_input_fp8 ctx.is_output_fp8 = is_output_fp8 + 
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward") return out_ret @staticmethod def backward(ctx, dout): # pylint: disable=missing-function-docstring + nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") cp_size = get_distributed_world_size(ctx.cp_group) ( @@ -4592,6 +4607,7 @@ def backward(ctx, dout): dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype) if not ctx.is_input_fp8: dq, dk, dv = [x.dequantize() for x in [dq, dk, dv]] + nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward") return ( None, diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 60c73a8d7d..d7a7f20dc4 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -24,12 +24,14 @@ ) from ..fp8 import FP8GlobalStateManager from ..utils import ( + assert_dim_for_fp8_exec, + cast_if_needed, + clear_tensor_data, divide, get_default_init_method, init_method_constant, - cast_if_needed, - assert_dim_for_fp8_exec, - clear_tensor_data, + nvtx_range_pop, + nvtx_range_push, requires_grad, ) from ..distributed import ( @@ -112,6 +114,12 @@ def forward( skip_fp8_weight_update: bool, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # pylint: disable=missing-function-docstring + + # NVTX label for profiling + nvtx_label = "transformer_engine._LayerNormLinear.forward" + if ub_name is not None: + nvtx_label = f"{nvtx_label}.{ub_name}" + # Make sure input dimensions are compatible out_features, in_features = weight.shape inp_shape = inp.shape @@ -121,10 +129,12 @@ def forward( assert_dim_for_fp8_exec(inputmat, weight) # Cast for native AMP + nvtx_range_push(f"{nvtx_label}.norm_input_cast") inputmat = cast_if_needed(inputmat, activation_dtype) ln_weight = cast_if_needed(ln_weight, activation_dtype) if ln_bias is not None: ln_bias = cast_if_needed(ln_bias, activation_dtype) + nvtx_range_pop(f"{nvtx_label}.norm_input_cast") tp_world_size = get_distributed_world_size(tp_group) ub_overlap_ag_fprop = ( @@ -175,6 +185,7 @@ def forward( ) # Apply normalization + nvtx_range_push(f"{nvtx_label}.norm") ln_out, mu, rsigma = apply_normalization( inputmat, ln_out, @@ -188,9 +199,11 @@ def forward( zero_centered_gamma, ) ln_out_return = ln_out if return_layernorm_output else None + nvtx_range_pop(f"{nvtx_label}.norm") # Prepare GEMM input # Note: Cast to expected dtype and perform tensor-parallel communication + nvtx_range_push(f"{nvtx_label}.gemm_input_cast_comm") if with_input_all_gather and not ub_overlap_ag_fprop: with_quantized_all_gather = fp8 if return_layernorm_output and return_layernorm_output_gathered: @@ -217,6 +230,7 @@ def forward( elif backward_needs_input: ln_out.update_usage(rowwise_usage=True, columnwise_usage=True) ln_out_total = ln_out + nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm") # Cast weight to expected dtype weightmat = weight @@ -275,6 +289,7 @@ def forward( assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer." 
ln_out_total = ub_obj.get_buffer(input_quantizer) + nvtx_range_push(f"{nvtx_label}.gemm") out, *_, rs_out = general_gemm( weightmat, ln_out_total, @@ -287,6 +302,8 @@ def forward( ub_type=ub_type, extra_output=rs_out, ) + nvtx_range_pop(f"{nvtx_label}.gemm") + if not weight.requires_grad: if not return_layernorm_output: ln_out = ln_out_total = None @@ -307,6 +324,7 @@ def forward( # Scatter intermediate/activation tensors saved for the backward pass # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already # shards/unshards the base weights so we don't do it ourselves + nvtx_range_push(f"{nvtx_label}.fsdp_scatter") ctx.fsdp_group = fsdp_group ctx.fsdp_shapes = _fsdp_scatter_tensors( fsdp_group, @@ -315,6 +333,7 @@ def forward( weightmat if quantized_weight else None, ln_out if weight.requires_grad else None, ) + nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") tensors_to_save, tensor_objects = prepare_for_saving( inputmat, @@ -372,10 +391,12 @@ def forward( if ub_overlap_rs_fprop: out = rs_out elif parallel_mode == "row": + nvtx_range_push(f"{nvtx_label}.row_parallel_comm") if sequence_parallel: out, _ = reduce_scatter_along_first_dim(out, tp_group) elif tensor_parallel: out, _ = allreduce(out, tp_group) + nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") # [*, in_features] -> [*, out_features] except first dimension changes for SP out = out.view(-1, *inp_shape[1:-1], out_features) @@ -394,6 +415,11 @@ def backward( ) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring + # NVTX label for profiling + nvtx_label = "transformer_engine._LayerNormLinear.backward" + if ctx.ub_name is not None: + nvtx_label = f"{nvtx_label}.{ctx.ub_name}" + with torch.cuda.nvtx.range("_LayerNormLinear_backward"): if ( ctx.fp8 @@ -433,6 +459,7 @@ def backward( # Gather intermediate/activation tensors if needed # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already # shards/unshards the base weights so we don't do it ourselves + nvtx_range_push(f"{nvtx_label}.fsdp_gather") _fsdp_gather_tensors( ctx.fsdp_group, ctx.fsdp_shapes, @@ -441,6 +468,7 @@ def backward( weight if ctx.fp8 and ctx.quantized_weight else None, ln_out, ) + nvtx_range_pop(f"{nvtx_label}.fsdp_gather") # For CPU offloading, we offloaded weight and weight.main_grad to different tensors, # we need to connect them into one. 
@@ -515,12 +543,14 @@ def backward( if ctx.fp8: quantizer = ctx.input_quantizer quantizer.set_usage(rowwise=True, columnwise=True) + nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") ln_out_total, ln_out_total_work = gather_along_first_dim( ln_out, ctx.tp_group, async_op=True, quantizer=quantizer, ) + nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input") else: ln_out_total = ln_out @@ -536,6 +566,7 @@ def backward( if ctx.grad_input_quantizer is not None: ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) + nvtx_range_push(f"{nvtx_label}.dgrad_gemm") dgrad, *_ = general_gemm( weight, grad_output, @@ -551,12 +582,14 @@ def backward( extra_output=rs_out, bulk_overlap=ctx.ub_bulk_dgrad, ) + nvtx_range_pop(f"{nvtx_label}.dgrad_gemm") # Launch tensor-parallel communication dgrad_work = None if ctx.ub_overlap_rs_dgrad: dgrad = rs_out elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad: + nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad") if ctx.sequence_parallel: if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered: dgrad = dgrad + grad_outputs[1].view_as(dgrad) @@ -567,6 +600,7 @@ def backward( ) else: dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True) + nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad") # Compute grad weight tensor wgrad = None @@ -603,6 +637,7 @@ def backward( # wgrad GEMM # Note: Fuse with bgrad computation if needed + nvtx_range_push(f"{nvtx_label}.wgrad_gemm") wgrad, grad_bias_, *_, rs_out = general_gemm( ln_out_total, grad_output, @@ -621,6 +656,7 @@ def backward( extra_output=rs_out, bulk_overlap=ctx.ub_bulk_wgrad, ) + nvtx_range_pop(f"{nvtx_label}.wgrad_gemm") if ctx.ub_bulk_wgrad: if ub_obj_wgrad.is_fp8_ubuf(): @@ -657,6 +693,7 @@ def backward( # Norm gradient dgamma = None dbeta = None + nvtx_range_push(f"{nvtx_label}.norm") if ctx.normalization == "LayerNorm": dgrad, dgamma, dbeta = tex.layernorm_bwd( dgrad, @@ -679,6 +716,7 @@ def backward( ) dgrad = dgrad.reshape(inputmat.size()) dbeta = None + nvtx_range_pop(f"{nvtx_label}.norm") clear_tensor_data(mu) clear_tensor_data(rsigma) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index 460ce87bc6..415cc7d9a9 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -22,12 +22,14 @@ from ._common import noop_cat, _fix_gathered_fp8_transpose from ..fp8 import FP8GlobalStateManager from ..utils import ( - divide, cast_if_needed, clear_tensor_data, + divide, init_method_constant, - requires_grad, non_tn_fp8_gemm_supported, + nvtx_range_pop, + nvtx_range_push, + requires_grad, ) from ..distributed import ( set_tensor_model_parallel_attributes, @@ -100,6 +102,11 @@ def forward( ) -> torch.Tensor: # pylint: disable=missing-function-docstring + # NVTX label for profiling + nvtx_label = "transformer_engine._Linear.forward" + if ub_name is not None: + nvtx_label = f"{nvtx_label}.{ub_name}" + # Make sure input dimensions are compatible out_features, in_features = weight.shape inp_shape = inp.shape @@ -110,6 +117,7 @@ def forward( # Prepare input tensor # Note: Cast to expected dtype and perform tensor-parallel communication + nvtx_range_push(f"{nvtx_label}.input_cast_comm") inputmat = inp inputmat_total = None with_input_all_gather_nccl = ( @@ -153,6 +161,7 @@ def forward( inputmat_total, _ = gather_along_first_dim(inputmat, tp_group) else: inputmat_total = inputmat + nvtx_range_pop(f"{nvtx_label}.input_cast_comm") # Cast weight to 
expected dtype weightmat = weight @@ -216,6 +225,7 @@ def forward( ub_obj.copy_into_buffer(inputmat_total, input_quantizer, local_chunk=True) inputmat_total = ub_obj.get_buffer(input_quantizer) + nvtx_range_push(f"{nvtx_label}.gemm") out, *_, rs_out = general_gemm( weightmat, inputmat_total, @@ -228,6 +238,7 @@ def forward( ub_type=ub_type, extra_output=rs_out, ) + nvtx_range_pop(f"{nvtx_label}.gemm") if is_grad_enabled: saved_inputmat = None @@ -244,12 +255,14 @@ def forward( # Scatter intermediate/activation tensors saved for the backward pass # NOTE: FSDP sharding is not valid for models initialized with primary Fp8 weights + nvtx_range_push(f"{nvtx_label}.fsdp_scatter") ctx.fsdp_group = fsdp_group ctx.fsdp_shapes = _fsdp_scatter_tensors( fsdp_group, saved_inputmat, weightmat if fp8 and not isinstance(weight, QuantizedTensor) else None, ) + nvtx_range_pop(f"{nvtx_label}.fsdp_scatter") # TODO(ksivamani): Check memory usage tensors_to_save, tensor_objects = prepare_for_saving( @@ -299,10 +312,12 @@ def forward( if ub_overlap_rs_fprop: out = rs_out elif parallel_mode == "row": + nvtx_range_push(f"{nvtx_label}.row_parallel_comm") if sequence_parallel: out, _ = reduce_scatter_along_first_dim(out, tp_group) elif tensor_parallel: out, _ = allreduce(out, tp_group) + nvtx_range_pop(f"{nvtx_label}.row_parallel_comm") out = out.view(-1, *inp_shape[1:-1], out_features) return out @@ -311,6 +326,11 @@ def forward( def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ...]: # pylint: disable=missing-function-docstring + # NVTX label for profiling + nvtx_label = "transformer_engine._Linear.backward" + if ctx.ub_name is not None: + nvtx_label = f"{nvtx_label}.{ctx.ub_name}" + with torch.cuda.nvtx.range("_Linear_backward"): if ( ctx.fp8 @@ -347,12 +367,14 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # Gather intermediate/activation tensors if needed # NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.disttributed.FSDP already # shards/unshards the base weights so we don't do it ourselves + nvtx_range_push(f"{nvtx_label}.fsdp_gather") _fsdp_gather_tensors( ctx.fsdp_group, ctx.fsdp_shapes, inputmat, weight_fp8, ) + nvtx_range_pop(f"{nvtx_label}.fsdp_gather") ctx.ub_obj_gradout = None ub_obj_dgrad = None @@ -424,12 +446,14 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], if ctx.fp8: quantizer = ctx.input_quantizer quantizer.set_usage(rowwise=True, columnwise=True) + nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input") inputmat_total, inputmat_total_work = gather_along_first_dim( inputmat, ctx.tp_group, async_op=True, quantizer=quantizer, ) + nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input") else: inputmat_total = inputmat @@ -451,6 +475,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False) # dgrad GEMM + nvtx_range_push(f"{nvtx_label}.dgrad_gemm") dgrad, *_, rs_out = general_gemm( weight_fp8, grad_output, @@ -466,11 +491,13 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], extra_output=rs_out, bulk_overlap=ctx.ub_bulk_dgrad, ) + nvtx_range_pop(f"{nvtx_label}.dgrad_gemm") # Launch tensor-parallel communication if ctx.ub_overlap_rs_dgrad: dgrad = rs_out elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad: + nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad") if ctx.sequence_parallel: dgrad, dgrad_work = reduce_scatter_along_first_dim( 
dgrad, @@ -479,6 +506,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], ) else: dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True) + nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad") # Compute grad weight tensor wgrad = None @@ -515,6 +543,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], # wgrad GEMM # Note: Fuse with bgrad computation if needed + nvtx_range_push(f"{nvtx_label}.wgrad_gemm") wgrad, grad_bias_, _, rs_out = general_gemm( inputmat_total, grad_output, @@ -533,6 +562,7 @@ def backward(ctx, grad_output: torch.Tensor) -> Tuple[Union[torch.Tensor, None], extra_output=rs_out, bulk_overlap=ctx.ub_bulk_wgrad, ) + nvtx_range_pop(f"{nvtx_label}.wgrad_gemm") if ctx.ub_bulk_wgrad: if ub_obj_wgrad.is_fp8_ubuf(): diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py index 5b1bd82221..1922a7e867 100644 --- a/transformer_engine/pytorch/utils.py +++ b/transformer_engine/pytorch/utils.py @@ -6,6 +6,7 @@ from __future__ import annotations import functools import math +import os from typing import Any, Callable, List, Optional, Tuple import torch @@ -326,3 +327,62 @@ def round_up_to_nearest_multiple(value, multiple): if multiple == 0: raise ValueError("multiple cannot be zero.") return ((value + multiple - 1) // multiple) * multiple + + +@functools.lru_cache(maxsize=None) +def _nvtx_enabled() -> bool: + """Check if NVTX range profiling is enabled""" + return bool(int(os.getenv("NVTE_NVTX_ENABLED", "0"))) + + +# Messages associated with active NVTX ranges +_nvtx_range_messages: list[str] = [] + + +def nvtx_range_push(msg: str) -> None: + """Push NVTX range onto stack, if NVTX range profiling is enabled + + Set `NVTE_NVTX_ENABLED=1` in the environment to enable NVTX range + profiling. + + Parameters + ---------- + msg: str + Message to associate with range + + """ + if not _nvtx_enabled(): + return + _nvtx_range_messages.append(msg) + torch.cuda.nvtx.range_push(msg) + + +def nvtx_range_pop(msg: Optional[str] = None) -> None: + """Pop NVTX range from stack, if NVTX range profiling is enabled + + Set `NVTE_NVTX_ENABLED=1` in the environment to enable NVTX range + profiling. + + Parameters + ---------- + msg: str, optional + Message associated with range + + """ + + # Return immediately if NVTX range profiling is not enabled + if not _nvtx_enabled(): + return + + # Update list of NVTX range messages and check for consistency + if not _nvtx_range_messages: + raise RuntimeError("Attempted to pop NVTX range from empty stack") + last_msg = _nvtx_range_messages.pop() + if msg is not None and msg != last_msg: + raise ValueError( + f"Attempted to pop NVTX range from stack with msg={msg}, " + f"but last range has msg={last_msg}" + ) + + # Pop NVTX range + torch.cuda.nvtx.range_pop()
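
Usage sketch (not part of the patch itself): the helpers added to utils.py are a thin, stack-checked wrapper around torch.cuda.nvtx and are disabled unless NVTE_NVTX_ENABLED=1 is set. Because _nvtx_enabled() is cached with lru_cache, the environment variable must be set before the first push/pop call. The snippet below shows how user code might wrap its own region with the same helpers; the layer and input names are placeholders, not identifiers from the patch.

    import os
    os.environ["NVTE_NVTX_ENABLED"] = "1"  # must be set before the first push/pop

    from transformer_engine.pytorch.utils import nvtx_range_pop, nvtx_range_push

    def profiled_step(layer, inp):
        # Hypothetical wrapper around one forward call; `layer` and `inp`
        # are placeholders, not names from the patch.
        nvtx_range_push("my_app.forward")  # no-op when profiling is disabled
        out = layer(inp)
        nvtx_range_pop("my_app.forward")   # message must match the corresponding push
        return out

Popping with a message that does not match the most recent push raises ValueError, and popping an empty stack raises RuntimeError, so unbalanced instrumentation is caught early. The resulting ranges appear alongside the built-in "transformer_engine.*" ranges when the run is captured with an NVTX-aware profiler such as Nsight Systems.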
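
Design note: inside attention.py, linear.py, and layernorm_linear.py the ranges are opened and closed with explicit push/pop calls rather than a context manager, which keeps the diff to simple one-line insertions around the existing forward and backward bodies. For user code that wraps a single region, a small context manager built on the same helpers keeps the stack balanced even if the body raises. The wrapper below is a sketch built on the new helpers and is not part of the patch:

    import contextlib

    from transformer_engine.pytorch.utils import nvtx_range_pop, nvtx_range_push

    @contextlib.contextmanager
    def nvtx_range(msg: str):
        # Hypothetical convenience wrapper: pushes a range on entry and always
        # pops it on exit, including when an exception propagates.
        nvtx_range_push(msg)
        try:
            yield
        finally:
            nvtx_range_pop(msg)

    # Example:
    #     with nvtx_range("my_app.optimizer_step"):
    #         optimizer.step()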