Commit 9e5abb5

Add NVTX ranges to attention
Signed-off-by: Jaemin Choi <jaeminc@nvidia.com>
Jaemin Choi committed Feb 1, 2025
1 parent ee3d302 commit 9e5abb5
Showing 1 changed file with 23 additions and 1 deletion.
transformer_engine/pytorch/attention.py: 23 additions & 1 deletion

@@ -23,7 +23,11 @@
 
 import transformer_engine_torch as tex
 import transformer_engine as te
-from transformer_engine.pytorch.utils import get_cudnn_version
+from transformer_engine.pytorch.utils import (
+    get_cudnn_version,
+    nvtx_range_push,
+    nvtx_range_pop,
+)
 from transformer_engine.pytorch.cpp_extensions.fused_attn import (
     fused_attn_fwd,
     fused_attn_bwd,
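
The `nvtx_range_push`/`nvtx_range_pop` helpers imported above live in `transformer_engine.pytorch.utils`; their definitions are outside this diff. As a minimal sketch, assuming they are thin wrappers over PyTorch's built-in NVTX bindings (the real helpers may differ, e.g. by short-circuiting when profiling is disabled):

import torch

def nvtx_range_push(msg: str) -> None:
    # Open an NVTX range named `msg` on the current thread; the range
    # shows up in NVTX-aware profilers such as Nsight Systems.
    torch.cuda.nvtx.range_push(msg)

def nvtx_range_pop() -> None:
    # Close the most recently opened NVTX range on the current thread.
    torch.cuda.nvtx.range_pop()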
@@ -1805,6 +1809,8 @@ class AttnFuncWithCPAndKVP2P(torch.autograd.Function):
     and `USP <https://arxiv.org/abs/2405.07719>`_.
     """
 
+    nvtx_label = "transformer_engine.attention.cp_and_kv_p2p"
+
     @staticmethod
     def forward(
         ctx,
@@ -1834,6 +1840,7 @@ def forward(
         quantizers,
     ):
         # pylint: disable=missing-function-docstring
+        nvtx_range_push(f"{AttnFuncWithCPAndKVP2P.nvtx_label}.forward")
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
 
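Each context-parallel autograd function defines a class-level `nvtx_label`, and the push at the top of `forward` and `backward` appends the phase name, so the resulting ranges are hierarchical and distinguish the three CP variants. Illustration only:

# The f-string above composes the class label and phase into one dotted name:
AttnFuncWithCPAndKVP2P.nvtx_label
# -> "transformer_engine.attention.cp_and_kv_p2p"
f"{AttnFuncWithCPAndKVP2P.nvtx_label}.forward"
# -> "transformer_engine.attention.cp_and_kv_p2p.forward"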
@@ -2756,12 +2763,14 @@ def forward(
         ctx.fp8_meta = fp8_meta
         ctx.is_input_fp8 = is_input_fp8
         ctx.is_output_fp8 = is_output_fp8
+        nvtx_range_pop()
 
         return out_ret
 
     @staticmethod
     def backward(ctx, dout):
         # pylint: disable=missing-function-docstring
+        nvtx_range_push(f"{AttnFuncWithCPAndKVP2P.nvtx_label}.backward")
         cp_size_a2a = ctx.cp_size_a2a
         rank_a2a = ctx.rank_a2a
 
@@ -3602,6 +3611,7 @@ def backward(ctx, dout):
         dq = ctx.dQKV_quantizer.create_tensor_from_data(dq, ctx.qkv_dtype)
         dk = ctx.dQKV_quantizer.create_tensor_from_data(dk, ctx.qkv_dtype)
         dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, ctx.qkv_dtype)
+        nvtx_range_pop()
 
         return (
             None,
@@ -3664,6 +3674,8 @@ class AttnFuncWithCPAndKVAllGather(torch.autograd.Function):
     Refer section 3.3.2 of `The Llama 3 Herd of Models <https://arxiv.org/abs/2407.21783>`_.
     """
 
+    nvtx_label = "transformer_engine.attention.cp_and_kv_allgather"
+
     @staticmethod
     def forward(
         ctx,
@@ -3688,6 +3700,7 @@ def forward(
         cp_stream,
     ):
         # pylint: disable=missing-function-docstring
+        nvtx_range_push(f"{AttnFuncWithCPAndKVAllGather.nvtx_label}.forward")
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
 
@@ -3904,11 +3917,13 @@ def forward(
         ctx.attn_mask_type = attn_mask_type
         ctx.deterministic = deterministic
         ctx.use_fused_attention = use_fused_attention
+        nvtx_range_pop()
         return out
 
     @staticmethod
     def backward(ctx, dout):
         # pylint: disable=missing-function-docstring
+        nvtx_range_push(f"{AttnFuncWithCPAndKVAllGather.nvtx_label}.backward")
         cp_size = get_distributed_world_size(ctx.cp_group)
         rank = get_distributed_rank(ctx.cp_group)
 
@@ -4092,6 +4107,7 @@ def backward(ctx, dout):
         dq = dq.view(*dq.shape[:seq_dim], -1, *dq.shape[(seq_dim + 2) :])
         dk = dk.movedim(0, seq_dim).contiguous()
         dv = dv.movedim(0, seq_dim).contiguous()
+        nvtx_range_pop()
 
         return (
             None,
@@ -4122,6 +4138,8 @@ class AttnFuncWithCPAndQKVOA2A(torch.autograd.Function):
     Refer the paper `DeepSpeed Ulysses <https://arxiv.org/abs/2309.14509>`_.
     """
 
+    nvtx_label = "transformer_engine.attention.cp_and_qkvo_a2a"
+
     @staticmethod
     def forward(
         ctx,
@@ -4151,6 +4169,7 @@ def forward(
         quantizers,
     ):
         # pylint: disable=missing-function-docstring
+        nvtx_range_push(f"{AttnFuncWithCPAndQKVOA2A.nvtx_label}.forward")
         if softmax_scale is None:
             softmax_scale = q.shape[-1] ** (-0.5)
 
@@ -4403,11 +4422,13 @@ def forward(
         ctx.fp8_meta = fp8_meta
         ctx.is_input_fp8 = is_input_fp8
         ctx.is_output_fp8 = is_output_fp8
+        nvtx_range_pop()
         return out_ret
 
     @staticmethod
     def backward(ctx, dout):
         # pylint: disable=missing-function-docstring
+        nvtx_range_push(f"{AttnFuncWithCPAndQKVOA2A.nvtx_label}.backward")
         cp_size = get_distributed_world_size(ctx.cp_group)
 
         (
@@ -4592,6 +4613,7 @@ def backward(ctx, dout):
         dv = ctx.dQKV_quantizer.create_tensor_from_data(dv, fake_dtype=dout_dtype)
         if not ctx.is_input_fp8:
             dq, dk, dv = [x.dequantize() for x in [dq, dk, dv]]
+        nvtx_range_pop()
 
         return (
             None,
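
With each push at function entry paired with a pop just before the return, every `forward`/`backward` of the three CP attention variants is bracketed by a named NVTX range, which an NVTX-aware profiler (e.g. `nsys profile --trace=cuda,nvtx python my_script.py`) renders as labeled spans on the timeline. A hypothetical usage sketch showing how user code could nest its own range alongside the ones added here (the range name and workload are illustrative, not part of this commit):

import torch
from transformer_engine.pytorch.utils import nvtx_range_push, nvtx_range_pop

def custom_block(x: torch.Tensor) -> torch.Tensor:
    # Bracket user code with the same helpers so it nests alongside the
    # transformer_engine.attention.* ranges added in this commit.
    nvtx_range_push("my_model.custom_block")  # hypothetical range name
    try:
        return 2.0 * x  # placeholder for real work
    finally:
        nvtx_range_pop()  # always close the range, even if the work raises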