clean CP implementation for flash attention and cuDNN 9.6 #1387

Merged Jan 8, 2025. 24 commits; changes shown from 21 commits.

Commits:
057de06  make pad_between_seqs check do not consider padding at the end (xrennvidia, Nov 29, 2024)
4bcc1ce  change CP THD test to make it consider 0-length sequence (xrennvidia, Nov 29, 2024)
62af46f  minor change to flash func name (xrennvidia, Nov 30, 2024)
fdc83fe  only use varlen func of flash attention while qkv_format is THD (xrennvidia, Dec 2, 2024)
a7e14bf  try to converge code of flash and fused attentions (xrennvidia, Dec 2, 2024)
cc2f3bc  Merge branch 'main' into xren/cp_optim (xrennvidia, Dec 4, 2024)
6041e8d  Merge branch 'main' into xren/cp_optim (xrennvidia, Dec 5, 2024)
685b08c  Merge branch 'main' into xren/cp_optim (xrennvidia, Dec 6, 2024)
d69198a  Merge branch 'main' into xren/cp_optim (xrennvidia, Dec 11, 2024)
7157ad7  Merge branch 'main' into xren/cp_optim (xrennvidia, Dec 17, 2024)
2bccdd0  Merge branch 'main' into xren/cp_optim (xrennvidia, Dec 20, 2024)
0e62ee1  fix bwd compute with P2P (xrennvidia, Dec 28, 2024)
1706ec4  remove redundant out_per_step view (xrennvidia, Dec 28, 2024)
620f86c  enable cudnn>9.6 and THD+GQA (xrennvidia, Dec 28, 2024)
9ca6f8d  enable CP with FusedAttn+SWA+All_Gather (xrennvidia, Dec 29, 2024)
7131139  enable CP with FusedAttn+SWA+All_Gather (xrennvidia, Dec 29, 2024)
da50f5b  code cleaning for cu_seqlens (xrennvidia, Dec 30, 2024)
d9443d4  [pre-commit.ci] auto fixes from pre-commit.com hooks (pre-commit-ci[bot], Dec 30, 2024)
09b4f4f  fix some pylint error (xrennvidia, Dec 30, 2024)
1b0eac9  minor import change for pylint (xrennvidia, Dec 30, 2024)
0dc2553  more fix for pylint (xrennvidia, Dec 30, 2024)
bdedcb8  fix lse_seqlen in thd out correction (xrennvidia, Jan 7, 2025)
2722165  Merge branch 'main' into xren/cp_optim (xrennvidia, Jan 7, 2025)
86109a6  Merge branch 'main' into xren/cp_optim (xrennvidia, Jan 8, 2025)
26 changes: 10 additions & 16 deletions tests/pytorch/fused_attn/run_fused_attn_with_cp.py
@@ -163,12 +163,10 @@ def run_dpa_with_cp(
                 torch.tensor([q_input_shape[0]], dtype=torch.int32),
             ]
         ).cuda()
-        if kernel_backend == "FlashAttention":
-            cu_seqlens_q = cu_seqlens_q_padded[:-1]
-        else:
-            cu_seqlens_q = torch.cat(
-                [torch.zeros([1], dtype=torch.int32), seqlens_q.cumsum(0, dtype=torch.int32)]
-            ).cuda()
+        cu_seqlens_q = torch.clone(cu_seqlens_q_padded)
+        if kernel_backend == "FusedAttention":
+            cu_seqlens_q[1:-1] = seqlens_q.cumsum(0, dtype=torch.int32).cuda()
+            cu_seqlens_q[-1] = cu_seqlens_q[-2]
         cu_seqlens_kv = cu_seqlens_q
         cu_seqlens_kv_padded = cu_seqlens_q_padded
     else:
@@ -204,10 +202,8 @@ def run_dpa_with_cp(
             core_attention_bias=bias,
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_kv=cu_seqlens_kv,
-            cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1],
-            cu_seqlens_kv_padded=(
-                None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1]
-            ),
+            cu_seqlens_q_padded=cu_seqlens_q_padded,
+            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
         )
         if fp8_mha:
             dout_fp8 = Float8Tensor.to_float8(dout, fp8_dtype=tex.DType.kFloat8E5M2)
@@ -276,10 +272,8 @@ def run_dpa_with_cp(
             core_attention_bias=bias_,
             cu_seqlens_q=cu_seqlens_q,
             cu_seqlens_kv=cu_seqlens_kv,
-            cu_seqlens_q_padded=None if cu_seqlens_q_padded is None else cu_seqlens_q_padded[:-1],
-            cu_seqlens_kv_padded=(
-                None if cu_seqlens_kv_padded is None else cu_seqlens_kv_padded[:-1]
-            ),
+            cu_seqlens_q_padded=cu_seqlens_q_padded,
+            cu_seqlens_kv_padded=cu_seqlens_kv_padded,
         )
         if fp8_mha:
             dout_fp8_ = Float8Tensor.to_float8(dout_, fp8_dtype=tex.DType.kFloat8E5M2)
@@ -311,7 +305,7 @@ def run_dpa_with_cp(
         dq, out = [x.index_select(0, seq_idx_q).contiguous() for x in [q.grad, out]]
         dk, dv = [x.index_select(0, seq_idx_kv).contiguous() for x in [k.grad, v.grad]]
         dq_, dk_, dv_, out_ = [q_.grad, k_.grad, v_.grad, out_]
-        cu_seqlens_q_padded = cu_seqlens_q_padded[:-1] // world_size
+        cu_seqlens_q_padded = cu_seqlens_q_padded // world_size
         cu_seqlens_q = get_cu_seqlens_on_cp_rank(
             cu_seqlens_q, cu_seqlens_q_padded, world_size, rank, True, True
         )
@@ -327,7 +321,7 @@ def run_dpa_with_cp(
             ).item()
             == 0
         )
-        cu_seqlens_kv_padded = cu_seqlens_kv_padded[:-1] // world_size
+        cu_seqlens_kv_padded = cu_seqlens_kv_padded // world_size
         cu_seqlens_kv = get_cu_seqlens_on_cp_rank(
             cu_seqlens_kv, cu_seqlens_kv_padded, world_size, rank, True, True
         )
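For reference, a minimal sketch of the cu_seqlens bookkeeping the updated test relies on, using hypothetical sequence lengths and run on CPU (the test itself moves these tensors to CUDA): cu_seqlens_q_padded holds cumulative offsets that include padding between sequences plus a trailing entry for padding at the end of the batch, while for FusedAttention cu_seqlens_q keeps the same length but carries unpadded cumulative lengths, with the last entry clamped to the final sequence end.

import torch

# Hypothetical THD batch of 3 sequences; values are made up for illustration.
seqlens_q = torch.tensor([6, 0, 5], dtype=torch.int32)         # actual token counts (note the 0-length sequence)
seqlens_q_padded = torch.tensor([8, 8, 8], dtype=torch.int32)  # padded per-sequence lengths
total_tokens = 32                                              # q_input_shape[0]: includes extra padding at the end

cu_seqlens_q_padded = torch.cat(
    [
        torch.zeros([1], dtype=torch.int32),
        seqlens_q_padded.cumsum(0, dtype=torch.int32),
        torch.tensor([total_tokens], dtype=torch.int32),
    ]
)
print(cu_seqlens_q_padded)  # tensor([ 0,  8, 16, 24, 32], dtype=torch.int32)

# FusedAttention path from the diff above: clone the padded offsets, overwrite the
# interior entries with unpadded cumulative lengths, and clamp the trailing entry.
cu_seqlens_q = torch.clone(cu_seqlens_q_padded)
cu_seqlens_q[1:-1] = seqlens_q.cumsum(0, dtype=torch.int32)
cu_seqlens_q[-1] = cu_seqlens_q[-2]
print(cu_seqlens_q)  # tensor([ 0,  6,  6, 11, 11], dtype=torch.int32)

# Per-rank bookkeeping used later in the test: the padded offsets divide evenly by world_size.
world_size = 2
print(cu_seqlens_q_padded // world_size)  # tensor([ 0,  4,  8, 12, 16], dtype=torch.int32)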
10 changes: 2 additions & 8 deletions tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -121,22 +121,14 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
         pytest.skip("CP implementation with KV all-gather is only supported with cuDNN >= 9.3.0!")
     if dtype == "fp8" and get_device_compute_capability() < (9, 0):
         pytest.skip("FP8 attention is only supported on sm90+!")
-    if qkv_format == "thd" and get_cudnn_version() >= (9, 6, 0):
-        pytest.skip("THD format is not supported for cuDNN 9.6+!")
 
     config = model_configs_fused_attn[model]
-    if qkv_format == "thd" and config.num_heads != config.num_gqa_groups:
-        pytest.skip("THD format does not support QGA/MQA yet!")
     if qkv_format == "thd" and config.attn_bias_type == "post_scale_bias":
         pytest.skip("THD format does not support post_scale_bias yet!")
     if qkv_format == "thd" and cp_comm_type == "all_gather":
         pytest.skip("CP implementation with KV all-gather does not support THD format yet!")
     if qkv_format == "thd" and "a2a" in cp_comm_type:
         pytest.skip("CP implementation with QKVO A2A does not support THD format yet!")
-    if config.window_size != (-1, 0) and config.window_size != (-1, -1) and cp_comm_type != "a2a":
-        pytest.skip(
-            "Sliding window attention only can be supported with the implementation of QKVO A2A!"
-        )
     if dtype == "fp8" and cp_comm_type == "all_gather":
         pytest.skip(
             "CP implementation with KV all-gather does not support FP8 + context parallelism yet!"
@@ -147,6 +139,8 @@ def test_cp_with_fused_attention(dtype, model, qkv_format, cp_comm_type, fp8_mha
         pytest.skip("FP8 attention cannot work with bias yet!")
     if dtype == "fp8" and config.window_size != (-1, 0) and config.window_size != (-1, -1):
         pytest.skip("FP8 attention cannot work with sliding window yet!")
+    if "p2p" in cp_comm_type and config.window_size != (-1, 0) and config.window_size != (-1, -1):
+        pytest.skip("CP implementation with KV P2P does not support sliding window yet!")
     if cp_comm_type == "all_gather" and config.attn_bias_type != "no_bias":
         pytest.skip("CP implementation with KV all-gather does not support bias yet!")
     if "a2a" in cp_comm_type and config.attn_bias_type != "no_bias":
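A compact way to read the window-size changes above: the removed skip ran sliding window attention only when cp_comm_type was exactly "a2a", while the added skip only rules out the KV P2P exchange. A rough sketch of the resulting rule follows; the helper name and the asserts are ours, not part of the test, and window_size uses the (-1, 0) causal / (-1, -1) full-attention convention from the model configs.

def swa_runs_with_cp(cp_comm_type: str, window_size: tuple) -> bool:
    """Hypothetical summary of the sliding-window skips after this PR."""
    if window_size in ((-1, 0), (-1, -1)):
        return True  # causal or full attention: no window-specific skip applies
    # Before this PR the test skipped sliding window unless cp_comm_type was exactly "a2a";
    # now only the KV P2P exchange is skipped, so the all-gather and A2A paths run.
    return "p2p" not in cp_comm_type

assert swa_runs_with_cp("a2a", (512, 0))
assert swa_runs_with_cp("all_gather", (512, 0))  # newly exercised by this PR
assert not swa_runs_with_cp("p2p", (512, 0))     # still skipped
assert swa_runs_with_cp("p2p", (-1, 0))          # causal mask, no sliding window involved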