[PyTorch] upgrade context parallelism implementations #572

Merged: 52 commits merged into main from xren/cp_with_fused_attn on Jan 10, 2024

Commits (52)
b0c887b
try to use cuDNN fused attention for context parallelism
xrennvidia Oct 12, 2023
067a50c
Merge branch 'main' into xren/cp_with_fused_attn
xrennvidia Oct 18, 2023
9db666f
assert CP is only supported with NVTE_F16_arbitrary_seqlen
xrennvidia Oct 18, 2023
250ee38
Merge branch 'main' into xren/cp_with_fused_attn
xrennvidia Oct 19, 2023
901d22f
Merge branch 'main' into xren/cp_with_fused_attn
xrennvidia Oct 20, 2023
aef3e32
port fused attn api to context parallelism
xrennvidia Oct 23, 2023
8ee5adf
Merge branch 'main' into xren/cp_with_fused_attn
xrennvidia Oct 24, 2023
82f1580
add one more assert
xrennvidia Oct 24, 2023
63d6aac
assert CP does not support padded tokens
xrennvidia Oct 24, 2023
8ea88c4
add qkv_format into CP implementation
xrennvidia Oct 26, 2023
dab905e
Merge branch 'main' into xren/cp_with_fused_attn
xrennvidia Oct 26, 2023
07a0fab
remove qkv_format from CP function
xrennvidia Oct 30, 2023
c82b3b3
fix qkv_format
xrennvidia Oct 30, 2023
72ae65d
Merge branch 'main' into xren/cp_with_fused_attn
xrennvidia Oct 31, 2023
cd8ca2d
fix bwd error with FA v2
xrennvidia Oct 31, 2023
4cbeb25
fix bwd issue with fa v2 and cudnn fa
xrennvidia Nov 1, 2023
a4337ad
make cp implementation support non-causal masking
xrennvidia Nov 1, 2023
d00f620
bug fix
xrennvidia Nov 1, 2023
b74c6ce
merge with main
xrennvidia Nov 3, 2023
0b1d529
Merge branch 'main' into xren/cp_with_fused_attn
xrennvidia Nov 14, 2023
c6c1660
Merge branch 'main' into xren/cp_with_fused_attn
xrennvidia Nov 18, 2023
570e832
Merge branch 'main' into xren/cp_with_fused_attn
xrennvidia Nov 23, 2023
ccecbbd
Merge branch 'main' into xren/cp_with_fused_attn
xrennvidia Dec 1, 2023
5678457
Merge branch 'main' into xren/cp_with_fused_attn
xrennvidia Dec 4, 2023
d2f4771
merge with main
xrennvidia Dec 7, 2023
9f6633c
remove redundant asserts for CP
xrennvidia Dec 8, 2023
8b33e3d
minor assert information change
xrennvidia Dec 8, 2023
9696877
assert core attn bias has not been supported with CP yet
xrennvidia Dec 8, 2023
ffa311c
Merge branch 'main' into xren/cp_with_fused_attn
xrennvidia Dec 13, 2023
c7f6433
Merge branch 'main' into xren/cp_with_fused_attn
xrennvidia Dec 15, 2023
b14174e
Merge branch 'main' into xren/cp_with_fused_attn
xrennvidia Dec 18, 2023
1b07e96
make CP work with window_sizes of [-1, -1] and [-1, 0]
xrennvidia Dec 19, 2023
376782e
Merge branch 'main' into xren/cp_with_fused_attn
xrennvidia Dec 21, 2023
1679647
Merge branch 'main' into xren/cp_with_fused_attn
ksivaman Jan 2, 2024
0267130
add draft code for fa test with cp
xrennvidia Jan 3, 2024
4907c21
move fused attn test to a specific folder
xrennvidia Jan 3, 2024
cd8ee2a
merge with main
xrennvidia Jan 3, 2024
ebee297
add assert_close to flash attn cp test
xrennvidia Jan 3, 2024
c3e333e
add more tests for CP
xrennvidia Jan 4, 2024
8191797
add optional arguments for FA v2.4+
xrennvidia Jan 4, 2024
ea8b145
minor change
xrennvidia Jan 4, 2024
3de56e6
add skip condition for CP test
xrennvidia Jan 4, 2024
080744d
Merge branch 'main' into xren/cp_with_fused_attn
xrennvidia Jan 4, 2024
31b83d1
Merge branch 'main' into xren/cp_with_fused_attn
cyanguwa Jan 5, 2024
5719a2d
class and function naming fix
xrennvidia Jan 5, 2024
4bfb36c
docstring fix
xrennvidia Jan 5, 2024
ed43920
do not use fused attn if backend does not work with CP
xrennvidia Jan 6, 2024
553c80c
create a separate folder for CP test as it needs multi-GPUs
xrennvidia Jan 6, 2024
f5855bd
add attn_mask_type check in attn_forward_func_with_cp
xrennvidia Jan 6, 2024
efc5f6f
merge with main
xrennvidia Jan 8, 2024
f9ba250
Merge branch 'main' into xren/cp_with_fused_attn
xrennvidia Jan 8, 2024
d7fa391
code format fix
xrennvidia Jan 8, 2024
10 changes: 10 additions & 0 deletions qa/L0_pytorch_distributed_unittest/test.sh
@@ -0,0 +1,10 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

set -e

: ${TE_PATH:=/opt/transformerengine}

pip install pytest==6.2.5 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn_with_cp.py
2 changes: 1 addition & 1 deletion qa/L0_pytorch_unittest/test.sh
@@ -10,6 +10,6 @@ pip install pytest==6.2.5 onnxruntime==1.13.1
pytest -v -s $TE_PATH/tests/pytorch/test_sanity.py
PYTORCH_JIT=0 NVTE_ALLOW_NONDETERMINISTIC_ALGO=0 pytest -v -s $TE_PATH/tests/pytorch/test_numerics.py
pytest -v -s $TE_PATH/tests/pytorch/test_jit.py
-pytest -v -s $TE_PATH/tests/pytorch/test_fused_attn.py
+pytest -v -s $TE_PATH/tests/pytorch/fused_attn/test_fused_attn.py
NVTE_TORCH_COMPILE=0 pytest -v -s $TE_PATH/tests/pytorch/test_onnx_export.py
pytest -v -s $TE_PATH/tests/pytorch/test_float8tensor.py
134 changes: 134 additions & 0 deletions tests/pytorch/fused_attn/run_fused_attn_with_cp.py
@@ -0,0 +1,134 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os, sys
import torch
import torch.distributed as dist
from transformer_engine.pytorch.attention import DotProductAttention
from test_fused_attn_with_cp import model_configs

dtypes={'fp16' : torch.float16, 'bf16' : torch.bfloat16}

def run_dpa_with_cp(dtype='bf16', model=None, qkv_format='bshd', kernel_backend='FlashAttention'):
"""Test DotProductAttention module with context parallelism"""

os.environ["NVTE_FLASH_ATTN"] = "0"
os.environ["NVTE_FUSED_ATTN"] = "0"
if kernel_backend == "FlashAttention":
os.environ["NVTE_FLASH_ATTN"] = "1"
if kernel_backend == "FusedAttention":
os.environ["NVTE_FUSED_ATTN"] = "1"

rank = int(os.getenv('RANK', '0'))
world_size = int(os.getenv('WORLD_SIZE', '1'))

if dist.is_initialized():
world_size = dist.get_world_size()
rank = dist.get_rank()
else:
device_count = torch.cuda.device_count()
device = rank % device_count
torch.cuda.set_device(device)

print(f"[INFO] world_size:{world_size}, rank:{rank}")

dist.init_process_group(backend='nccl', world_size=world_size, rank=rank)

# create flash attn comm group for CP
cp_comm_ranks = range(world_size)
assert(rank in cp_comm_ranks)
cp_comm_group = dist.new_group(cp_comm_ranks, backend='nccl')

config = model_configs[model]

assert config.attn_mask_type in ['causal', 'no_mask'], f"{config.attn_mask_type} is an unsupported attention mask type!"

# instantiate core attn module
core_attn = DotProductAttention(config.num_heads,
config.head_dim,
num_gqa_groups=config.num_gqa_groups,
attention_dropout=config.dropout_p,
qkv_format=qkv_format,
attn_mask_type=config.attn_mask_type)
core_attn = core_attn.cuda()

# create flash attn inputs
if qkv_format == "bshd":
q_input_shape = (config.batch_size, config.max_seqlen_q, config.num_heads, config.head_dim)
kv_input_shape = (config.batch_size, config.max_seqlen_kv, config.num_gqa_groups, config.head_dim)
attn_output_shape = (config.batch_size, config.max_seqlen_q, config.num_heads*config.head_dim)
elif qkv_format == "sbhd":
q_input_shape = (config.max_seqlen_q, config.batch_size, config.num_heads, config.head_dim)
kv_input_shape = (config.max_seqlen_kv, config.batch_size, config.num_gqa_groups, config.head_dim)
attn_output_shape = (config.max_seqlen_q, config.batch_size, config.num_heads*config.head_dim)
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"
q = torch.randn(q_input_shape, dtype=dtypes[dtype]).cuda()
k = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
v = torch.randn(kv_input_shape, dtype=dtypes[dtype]).cuda()
dout = torch.randn(attn_output_shape, dtype=dtypes[dtype]).cuda()

# make sure all GPU ranks have same inputs
for x in [q, k, v, dout]:
dist.broadcast(x, 0, group=cp_comm_group)

# run core_attn without CP
for x in [q, k, v]:
x.requires_grad = True
out = core_attn(q, k, v)
out.backward(dout)

# run core_attn with CP
q_, k_, v_, dout_ = [x.clone().detach() for x in [q, k, v, dout]]
seq_dim = qkv_format.index('s')
q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
for x in [q_, k_, v_, dout_]]
seq_idx = torch.tensor([rank, 2*world_size-rank-1], device=q_.device)
q_, k_, v_, dout_ = [x.index_select(seq_dim, seq_idx) for x in [q_, k_, v_, dout_]]
q_, k_, v_, dout_ = [x.view(*x.shape[:seq_dim], -1, *x.shape[(seq_dim+2):]) for x in [q_, k_, v_, dout_]]
q_, k_, v_ = [x.requires_grad_() for x in [q_, k_, v_]]
core_attn.set_context_parallel_group(cp_comm_group, cp_comm_ranks, torch.cuda.Stream())
out_ = core_attn(q_, k_, v_)
out_.backward(dout_)

for x in [out_, q_.grad, k_.grad, v_.grad]:
assert(torch.all(~torch.isnan(x)))
assert(torch.all(~torch.isinf(x)))

# compare results with and without CP
tols = dict(atol=5e-3, rtol=5e-3)
if dtype == 'bf16':
tols = dict(atol=2.5e-2, rtol=2.5e-2)
dq, dk, dv, out = [x.view(*x.shape[:seq_dim], 2*world_size, x.shape[seq_dim]//(2*world_size), *x.shape[(seq_dim+1):]) \
for x in [q.grad, k.grad, v.grad, out]]
dq, dk, dv, out = [x.index_select(seq_dim, seq_idx) for x in [dq, dk, dv, out]]
dq_, dk_, dv_, out_ = [x.view(*x.shape[:seq_dim], 2, x.shape[seq_dim]//2, *x.shape[(seq_dim+1):]) \
for x in [q_.grad, k_.grad, v_.grad, out_]]
if qkv_format == "bshd":
torch.testing.assert_close(out_[:, 0], out[:, 0], **tols)
torch.testing.assert_close(dq_[:, 0], dq[:, 0], **tols)
torch.testing.assert_close(dk_[:, 0], dk[:, 0], **tols)
torch.testing.assert_close(dv_[:, 0], dv[:, 0], **tols)
torch.testing.assert_close(out_[:, 1], out[:, 1], **tols)
torch.testing.assert_close(dq_[:, 1], dq[:, 1], **tols)
torch.testing.assert_close(dk_[:, 1], dk[:, 1], **tols)
torch.testing.assert_close(dv_[:, 1], dv[:, 1], **tols)
elif qkv_format == "sbhd":
torch.testing.assert_close(out_[0], out[0], **tols)
torch.testing.assert_close(dq_[0], dq[0], **tols)
torch.testing.assert_close(dk_[0], dk[0], **tols)
torch.testing.assert_close(dv_[0], dv[0], **tols)
torch.testing.assert_close(out_[1], out[1], **tols)
torch.testing.assert_close(dq_[1], dq[1], **tols)
torch.testing.assert_close(dk_[1], dk[1], **tols)
torch.testing.assert_close(dv_[1], dv[1], **tols)
else:
assert False, f"{qkv_format} is an unsupported qkv_format!"

def main(**kwargs):
run_dpa_with_cp(**kwargs)

if __name__ == "__main__":
kwargs = dict(arg.split('=') for arg in sys.argv[2:])
main(**kwargs)
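
For reference, the CP slicing in run_fused_attn_with_cp.py above splits each sequence into 2 * world_size chunks along the sequence dimension and keeps the chunk pair (rank, 2*world_size - 1 - rank) on each rank, which keeps causal-mask work balanced across ranks. A minimal standalone sketch of that slicing is below; the helper name shard_for_cp and the example shapes are illustrative only, not part of this PR.

# Illustrative sketch of the sequence sharding used by the CP test above
# (shard_for_cp is a hypothetical helper name, not part of this PR's diff).
import torch

def shard_for_cp(x: torch.Tensor, seq_dim: int, rank: int, world_size: int) -> torch.Tensor:
    # Split the sequence dimension into 2*world_size chunks and keep
    # chunks (rank, 2*world_size - 1 - rank), mirroring the view/index_select
    # logic in run_fused_attn_with_cp.py.
    chunks = 2 * world_size
    x = x.view(*x.shape[:seq_dim], chunks, x.shape[seq_dim] // chunks, *x.shape[seq_dim + 1:])
    idx = torch.tensor([rank, chunks - rank - 1], device=x.device)
    x = x.index_select(seq_dim, idx)
    return x.view(*x.shape[:seq_dim], -1, *x.shape[seq_dim + 2:])

# With world_size=2, rank 0 keeps chunks (0, 3) and rank 1 keeps chunks (1, 2).
q = torch.randn(1, 16384, 12, 128)  # bshd layout, as in config cp_1_0
q_local = shard_for_cp(q, seq_dim=1, rank=0, world_size=2)
assert q_local.shape[1] == 16384 // 2
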
59 changes: 59 additions & 0 deletions tests/pytorch/fused_attn/test_fused_attn_with_cp.py
@@ -0,0 +1,59 @@
# Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

import os
import pytest
import subprocess
from test_fused_attn import (
ModelConfig,
_is_flash_attention_2_available,
_cudnn_version,
)

model_configs = {
# test: b, h, hg, d, sq, skv, p, mask, bias
"cp_1_0": ModelConfig(1, 12, 12, 128, 16384, 16384, 0.0, "causal", "no_bias"), # MHA
"cp_1_1": ModelConfig(1, 12, 12, 128, 16384, 16384, 0.0, "no_mask", "no_bias"), # MHA
"cp_2_0": ModelConfig(1, 12, 1, 128, 16384, 16384, 0.0, "causal", "no_bias"), # GQA
"cp_2_1": ModelConfig(1, 12, 1, 128, 16384, 16384, 0.0, "no_mask", "no_bias"), # GQA
}

def get_bash_arguments(**kwargs):
args = ["python", "-m", "torch.distributed.launch", "--nproc-per-node=2"]
te_path = os.getenv("TE_PATH", "/opt/transformerengine")
script_path = os.path.join(te_path, "tests/pytorch/fused_attn/run_fused_attn_with_cp.py")
args.append(script_path)
for k, v in kwargs.items():
args.append(f"{k}={v}")
return args

@pytest.mark.skipif(not _is_flash_attention_2_available(), reason="Flash-attn 2.0+ is required.")
@pytest.mark.parametrize("dtype", ['bf16', 'fp16'])
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd'])
def test_cp_with_flash_attention(dtype, model, qkv_format):
subprocess.run(
get_bash_arguments(
dtype=dtype,
model=model,
qkv_format=qkv_format,
kernel_backend='FlashAttention'
),
check=True
)

@pytest.mark.skipif(_cudnn_version() < (8,9,7), reason="cuDNN 8.9.7+ is required.")
@pytest.mark.parametrize("dtype", ['bf16', 'fp16'])
@pytest.mark.parametrize("model", model_configs.keys())
@pytest.mark.parametrize("qkv_format", ['bshd', 'sbhd'])
def test_cp_with_fused_attention(dtype, model, qkv_format):
subprocess.run(
get_bash_arguments(
dtype=dtype,
model=model,
qkv_format=qkv_format,
kernel_backend='FusedAttention'
),
check=True
)
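
As a usage note, each parametrized test above launches run_fused_attn_with_cp.py on 2 GPUs via torch.distributed.launch. A rough example of the command assembled by get_bash_arguments() for one parametrization, assuming the default TE_PATH of /opt/transformerengine:

# Illustrative invocation only; the actual command is built by get_bash_arguments().
import subprocess

cmd = [
    "python", "-m", "torch.distributed.launch", "--nproc-per-node=2",
    "/opt/transformerengine/tests/pytorch/fused_attn/run_fused_attn_with_cp.py",
    "dtype=bf16", "model=cp_1_0", "qkv_format=bshd", "kernel_backend=FlashAttention",
]
subprocess.run(cmd, check=True)  # requires at least 2 GPUs with NCCL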