Add NVTX ranges to categorize execution #1447

Merged (1 commit, Feb 12, 2025)
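The helpers used throughout this diff, `nvtx_range_push` and `nvtx_range_pop`, are imported from `transformer_engine.pytorch.utils`; their definitions are not part of this change. A minimal sketch of what they plausibly look like, assuming they are thin wrappers over `torch.cuda.nvtx` and that the message passed to the pop helper is informational only (the real helpers may additionally validate it):

```python
# Hypothetical sketch of the profiling helpers assumed by this diff.
# The actual implementation in transformer_engine.pytorch.utils may differ.
from typing import Optional

import torch


def nvtx_range_push(msg: str) -> None:
    """Open an NVTX range labeled `msg` on the current thread."""
    torch.cuda.nvtx.range_push(msg)


def nvtx_range_pop(msg: Optional[str] = None) -> None:
    """Close the innermost open NVTX range.

    The optional label only makes call sites read as matched push/pop pairs;
    NVTX itself always pops the most recently pushed range.
    """
    torch.cuda.nvtx.range_pop()
```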
18 changes: 17 additions & 1 deletion transformer_engine/pytorch/attention.py
@@ -23,7 +23,11 @@

import transformer_engine_torch as tex
import transformer_engine as te
from transformer_engine.pytorch.utils import get_cudnn_version
from transformer_engine.pytorch.utils import (
get_cudnn_version,
nvtx_range_pop,
nvtx_range_push,
)
from transformer_engine.pytorch.cpp_extensions.fused_attn import (
fused_attn_fwd,
fused_attn_bwd,
@@ -1834,6 +1838,7 @@ def forward(
quantizers,
):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.forward")
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)

@@ -2756,12 +2761,14 @@ def forward(
ctx.fp8_meta = fp8_meta
ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.forward")

return out_ret

@staticmethod
def backward(ctx, dout):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVP2P.backward")
cp_size_a2a = ctx.cp_size_a2a
rank_a2a = ctx.rank_a2a

@@ -3602,6 +3609,7 @@ def backward(ctx, dout):
dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, ctx.qkv_dtype)
dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, ctx.qkv_dtype)
dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, ctx.qkv_dtype)
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVP2P.backward")

return (
None,
@@ -3688,6 +3696,7 @@ def forward(
cp_stream,
):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.forward")
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)

@@ -3904,11 +3913,13 @@ def forward(
ctx.attn_mask_type = attn_mask_type
ctx.deterministic = deterministic
ctx.use_fused_attention = use_fused_attention
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.forward")
return out

@staticmethod
def backward(ctx, dout):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndKVAllGather.backward")
cp_size = get_distributed_world_size(ctx.cp_group)
rank = get_distributed_rank(ctx.cp_group)

@@ -4092,6 +4103,7 @@ def backward(ctx, dout):
dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :])
dk = dk.movedim(0, seq_dim).contiguous()
dv = dv.movedim(0, seq_dim).contiguous()
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndKVAllGather.backward")

return (
None,
@@ -4151,6 +4163,7 @@ def forward(
quantizers,
):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
if softmax_scale is None:
softmax_scale = q.shape[-1] ** (-0.5)

@@ -4403,11 +4416,13 @@ def forward(
ctx.fp8_meta = fp8_meta
ctx.is_input_fp8 = is_input_fp8
ctx.is_output_fp8 = is_output_fp8
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.forward")
return out_ret

@staticmethod
def backward(ctx, dout):
# pylint: disable=missing-function-docstring
nvtx_range_push("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward")
cp_size = get_distributed_world_size(ctx.cp_group)

(
@@ -4592,6 +4607,7 @@ def backward(ctx, dout):
dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype)
if not ctx.is_input_fp8:
dq, dk, dv = [x.dequantize() for x in [dq, dk, dv]]
nvtx_range_pop("transformer_engine.AttnFuncWithCPAndQKVOA2A.backward")

return (
None,
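The attention.py changes above all follow one coarse pattern: push a range named `transformer_engine.<FunctionName>.forward`/`.backward` at the top of the autograd function and pop it just before returning. For illustration only, the same pattern applied to a hypothetical custom `autograd.Function` (`MyFusedOp` and its body are made up, not part of this PR):

```python
# Illustrative sketch: bracketing a custom autograd.Function with NVTX ranges
# in the same style as AttnFuncWithCPAndKVP2P above.
import torch

from transformer_engine.pytorch.utils import nvtx_range_pop, nvtx_range_push


class MyFusedOp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x):
        nvtx_range_push("transformer_engine.MyFusedOp.forward")
        out = 2 * x  # stand-in for the real fused kernel
        nvtx_range_pop("transformer_engine.MyFusedOp.forward")
        return out

    @staticmethod
    def backward(ctx, grad_out):
        nvtx_range_push("transformer_engine.MyFusedOp.backward")
        grad_in = 2 * grad_out
        nvtx_range_pop("transformer_engine.MyFusedOp.backward")
        return grad_in
```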
44 changes: 41 additions & 3 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -24,12 +24,14 @@
)
from ..fp8 import FP8GlobalStateManager
from ..utils import (
assert_dim_for_fp8_exec,
cast_if_needed,
clear_tensor_data,
divide,
get_default_init_method,
init_method_constant,
cast_if_needed,
assert_dim_for_fp8_exec,
clear_tensor_data,
nvtx_range_pop,
nvtx_range_push,
requires_grad,
)
from ..distributed import (
@@ -112,6 +114,12 @@ def forward(
skip_fp8_weight_update: bool,
) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]:
# pylint: disable=missing-function-docstring

# NVTX label for profiling
nvtx_label = "transformer_engine._LayerNormLinear.forward"
if ub_name is not None:
nvtx_label = f"{nvtx_label}.{ub_name}"

# Make sure input dimensions are compatible
out_features, in_features = weight.shape
inp_shape = inp.shape
@@ -121,10 +129,12 @@
assert_dim_for_fp8_exec(inputmat, weight)

# Cast for native AMP
nvtx_range_push(f"{nvtx_label}.norm_input_cast")
inputmat = cast_if_needed(inputmat, activation_dtype)
ln_weight = cast_if_needed(ln_weight, activation_dtype)
if ln_bias is not None:
ln_bias = cast_if_needed(ln_bias, activation_dtype)
nvtx_range_pop(f"{nvtx_label}.norm_input_cast")

tp_world_size = get_distributed_world_size(tp_group)
ub_overlap_ag_fprop = (
@@ -175,6 +185,7 @@ def forward(
)

# Apply normalization
nvtx_range_push(f"{nvtx_label}.norm")
ln_out, mu, rsigma = apply_normalization(
inputmat,
ln_out,
@@ -188,9 +199,11 @@
zero_centered_gamma,
)
ln_out_return = ln_out if return_layernorm_output else None
nvtx_range_pop(f"{nvtx_label}.norm")

# Prepare GEMM input
# Note: Cast to expected dtype and perform tensor-parallel communication
nvtx_range_push(f"{nvtx_label}.gemm_input_cast_comm")
if with_input_all_gather and not ub_overlap_ag_fprop:
with_quantized_all_gather = fp8
if return_layernorm_output and return_layernorm_output_gathered:
@@ -217,6 +230,7 @@
elif backward_needs_input:
ln_out.update_usage(rowwise_usage=True, columnwise_usage=True)
ln_out_total = ln_out
nvtx_range_pop(f"{nvtx_label}.gemm_input_cast_comm")

# Cast weight to expected dtype
weightmat = weight
@@ -275,6 +289,7 @@ def forward(
assert ub_obj.is_fp8_ubuf(), "AG overlap with FP8 GEMM inputs requires FP8 buffer."
ln_out_total = ub_obj.get_buffer(input_quantizer)

nvtx_range_push(f"{nvtx_label}.gemm")
out, *_, rs_out = general_gemm(
weightmat,
ln_out_total,
@@ -287,6 +302,8 @@
ub_type=ub_type,
extra_output=rs_out,
)
nvtx_range_pop(f"{nvtx_label}.gemm")

if not weight.requires_grad:
if not return_layernorm_output:
ln_out = ln_out_total = None
@@ -307,6 +324,7 @@
# Scatter intermediate/activation tensors saved for the backward pass
# NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already
# shards/unshards the base weights so we don't do it ourselves
nvtx_range_push(f"{nvtx_label}.fsdp_scatter")
ctx.fsdp_group = fsdp_group
ctx.fsdp_shapes = _fsdp_scatter_tensors(
fsdp_group,
@@ -315,6 +333,7 @@
weightmat if quantized_weight else None,
ln_out if weight.requires_grad else None,
)
nvtx_range_pop(f"{nvtx_label}.fsdp_scatter")

tensors_to_save, tensor_objects = prepare_for_saving(
inputmat,
@@ -372,10 +391,12 @@ def forward(
if ub_overlap_rs_fprop:
out = rs_out
elif parallel_mode == "row":
nvtx_range_push(f"{nvtx_label}.row_parallel_comm")
if sequence_parallel:
out, _ = reduce_scatter_along_first_dim(out, tp_group)
elif tensor_parallel:
out, _ = allreduce(out, tp_group)
nvtx_range_pop(f"{nvtx_label}.row_parallel_comm")

# [*, in_features] -> [*, out_features] except first dimension changes for SP
out = out.view(-1, *inp_shape[1:-1], out_features)
@@ -394,6 +415,11 @@ def backward(
) -> Tuple[Union[torch.Tensor, None], ...]:
# pylint: disable=missing-function-docstring

# NVTX label for profiling
nvtx_label = "transformer_engine._LayerNormLinear.backward"
if ctx.ub_name is not None:
nvtx_label = f"{nvtx_label}.{ctx.ub_name}"

with torch.cuda.nvtx.range("_LayerNormLinear_backward"):
if (
ctx.fp8
@@ -433,6 +459,7 @@ def backward(
# Gather intermediate/activation tensors if needed
# NOTE: weight_fp8 = weight when ctx.fp8 == False and torch.distributed.FSDP already
# shards/unshards the base weights so we don't do it ourselves
nvtx_range_push(f"{nvtx_label}.fsdp_gather")
_fsdp_gather_tensors(
ctx.fsdp_group,
ctx.fsdp_shapes,
@@ -441,6 +468,7 @@ def backward(
weight if ctx.fp8 and ctx.quantized_weight else None,
ln_out,
)
nvtx_range_pop(f"{nvtx_label}.fsdp_gather")

# For CPU offloading, we offloaded weight and weight.main_grad to different tensors,
# we need to connect them into one.
@@ -515,12 +543,14 @@ def backward(
if ctx.fp8:
quantizer = ctx.input_quantizer
quantizer.set_usage(rowwise=True, columnwise=True)
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_input")
ln_out_total, ln_out_total_work = gather_along_first_dim(
ln_out,
ctx.tp_group,
async_op=True,
quantizer=quantizer,
)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_input")
else:
ln_out_total = ln_out

@@ -536,6 +566,7 @@ def backward(
if ctx.grad_input_quantizer is not None:
ctx.grad_input_quantizer.set_usage(rowwise=True, columnwise=False)

nvtx_range_push(f"{nvtx_label}.dgrad_gemm")
dgrad, *_ = general_gemm(
weight,
grad_output,
@@ -551,12 +582,14 @@ def backward(
extra_output=rs_out,
bulk_overlap=ctx.ub_bulk_dgrad,
)
nvtx_range_pop(f"{nvtx_label}.dgrad_gemm")

# Launch tensor-parallel communication
dgrad_work = None
if ctx.ub_overlap_rs_dgrad:
dgrad = rs_out
elif ctx.parallel_mode == "column" and not ctx.ub_bulk_wgrad:
nvtx_range_push(f"{nvtx_label}.column_parallel_comm_dgrad")
if ctx.sequence_parallel:
if ctx.return_layernorm_output and ctx.return_layernorm_output_gathered:
dgrad = dgrad + grad_outputs[1].view_as(dgrad)
@@ -567,6 +600,7 @@ def backward(
)
else:
dgrad, dgrad_work = allreduce(dgrad, ctx.tp_group, async_op=True)
nvtx_range_pop(f"{nvtx_label}.column_parallel_comm_dgrad")

# Compute grad weight tensor
wgrad = None
@@ -603,6 +637,7 @@ def backward(

# wgrad GEMM
# Note: Fuse with bgrad computation if needed
nvtx_range_push(f"{nvtx_label}.wgrad_gemm")
wgrad, grad_bias_, *_, rs_out = general_gemm(
ln_out_total,
grad_output,
@@ -621,6 +656,7 @@ def backward(
extra_output=rs_out,
bulk_overlap=ctx.ub_bulk_wgrad,
)
nvtx_range_pop(f"{nvtx_label}.wgrad_gemm")

if ctx.ub_bulk_wgrad:
if ub_obj_wgrad.is_fp8_ubuf():
@@ -657,6 +693,7 @@ def backward(
# Norm gradient
dgamma = None
dbeta = None
nvtx_range_push(f"{nvtx_label}.norm")
if ctx.normalization == "LayerNorm":
dgrad, dgamma, dbeta = tex.layernorm_bwd(
dgrad,
@@ -679,6 +716,7 @@ def backward(
)
dgrad = dgrad.reshape(inputmat.size())
dbeta = None
nvtx_range_pop(f"{nvtx_label}.norm")
clear_tensor_data(mu)
clear_tensor_data(rsigma)

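These ranges show up once the run is captured with an NVTX-aware profiler such as Nsight Systems (for example `nsys profile --trace=cuda,nvtx python step.py`; a typical invocation, not taken from this PR). A minimal sketch that exercises both the forward and backward labels added in layernorm_linear.py:

```python
# Minimal sketch: run one LayerNormLinear forward/backward so the new
# transformer_engine._LayerNormLinear.* NVTX ranges are emitted.
# Assumes Transformer Engine is installed and a CUDA device is available.
import torch
import transformer_engine.pytorch as te

layer = te.LayerNormLinear(1024, 1024).cuda()
x = torch.randn(8, 1024, device="cuda", requires_grad=True)

torch.cuda.nvtx.range_push("train_step")  # user-level range enclosing TE's ranges
out = layer(x)          # emits transformer_engine._LayerNormLinear.forward.* sub-ranges
out.sum().backward()    # emits transformer_engine._LayerNormLinear.backward.* sub-ranges
torch.cuda.nvtx.range_pop()
```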