[PyTorch] upgrade context parallelism implementations #572

Merged
cyanguwa merged 52 commits into NVIDIA:main from xren/cp_with_fused_attn
Jan 10, 2024

Conversation

xrennvidia
Collaborator

  1. Port the cuDNN Flash Attn API to the CP implementation
  2. Support both unidirectional and bidirectional attention
  3. Make the CP implementation work with window_size values of [-1, -1] and [-1, 0] (sketched below)
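In the flash-attn convention, window_size = (left, right) bounds how far each query can attend backward and forward, with -1 meaning unbounded, so [-1, -1] is full bidirectional attention and [-1, 0] is causal attention. A minimal standalone sketch of that masking rule (an illustration of the convention, not this PR's code):

import torch

def window_mask(seq_len, window_size=(-1, -1)):
    # True = attention allowed; window_size = (left, right), -1 = unbounded.
    left, right = window_size
    q_idx = torch.arange(seq_len).unsqueeze(1)  # query positions (column)
    k_idx = torch.arange(seq_len).unsqueeze(0)  # key positions (row)
    allow = torch.ones(seq_len, seq_len, dtype=torch.bool)
    if left >= 0:
        allow &= (q_idx - k_idx) <= left    # limit how far back a query looks
    if right >= 0:
        allow &= (k_idx - q_idx) <= right   # limit how far ahead a query looks
    return allow

# [-1, -1] keeps everything visible (bidirectional); [-1, 0] is lower-triangular (causal).
assert window_mask(4, (-1, -1)).all()
assert torch.equal(window_mask(4, (-1, 0)), torch.tril(torch.ones(4, 4, dtype=torch.bool)))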

Signed-off-by: xren <xren@nvidia.com>
Signed-off-by: Xiaowei Ren <xren@nvidia.com>
@cyanguwa
Collaborator

cyanguwa commented Jan 5, 2024

/te-ci pytorch

q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1], TE_DType[q.dtype],
tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
attn_scale=softmax_scale, dropout=dropout_p,
qkv_layout="bshd_bshd_bshd", attn_mask_type="causal",
)
else:
Collaborator

FlashAttention and FusedAttention might not be mutually exclusive; they could both be True or both be False. Maybe it's better to pass another flag, use_flash_attention, in here?
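A minimal standalone sketch of the suggestion, with a hypothetical helper name, just to show the two flags being treated independently (not the actual function signature in this PR):

def select_attn_backend(use_fused_attention: bool, use_flash_attention: bool) -> str:
    # Hypothetical helper: both flags are explicit inputs, so "not fused" is not
    # silently interpreted as "flash".
    if use_fused_attention:
        return "fused"    # cuDNN fused attention path
    if use_flash_attention:
        return "flash"    # flash-attn path
    return "unfused"      # framework-native fallback

assert select_attn_backend(False, False) == "unfused"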

@cyanguwa
Collaborator

cyanguwa commented Jan 6, 2024

/te-ci pytorch

@Infi-zc

Infi-zc commented Jan 8, 2024

Hi @xrennvidia! Thanks for the contributions to context parallelism. I have been using context parallelism, and it has effectively addressed the issue of training with excessively long contexts. However, since handling input sequences with variable-length sub-sequences (called varlen in flash_attn) natively could make attention computation more efficient, we are also interested in training with varlen and context parallelism. May I ask whether there are any plans to support context parallelism with variable sequence lengths?

@Infi-zc

Infi-zc commented Jan 8, 2024

> Hi @xrennvidia! Thanks for the contributions to context parallelism. I have been using context parallelism, and it has effectively addressed the issue of training with excessively long contexts. However, since handling input sequences with variable-length sub-sequences (called varlen in flash_attn) natively could make attention computation more efficient, we are also interested in training with varlen and context parallelism. May I ask whether there are any plans to support context parallelism with variable sequence lengths?

Hi @cyanguwa, I was wondering whether there are any plans to support variable-length sequences in context parallelism within Transformer Engine. I'm interested to know whether it's on the roadmap. Thanks!

@cyanguwa
Collaborator

cyanguwa commented Jan 8, 2024

> Hi @xrennvidia! Thanks for the contributions to context parallelism. I have been using context parallelism, and it has effectively addressed the issue of training with excessively long contexts. However, since handling input sequences with variable-length sub-sequences (called varlen in flash_attn) natively could make attention computation more efficient, we are also interested in training with varlen and context parallelism. May I ask whether there are any plans to support context parallelism with variable sequence lengths?
>
> Hi @cyanguwa, I was wondering whether there are any plans to support variable-length sequences in context parallelism within Transformer Engine. I'm interested to know whether it's on the roadmap. Thanks!

I think there is some plan to add variable-length support on the cuDNN side, but I don't think there is a definitive timeline at the moment.

@xrennvidia
Collaborator Author

xrennvidia commented Jan 8, 2024

> Hi @xrennvidia! Thanks for the contributions to context parallelism. I have been using context parallelism, and it has effectively addressed the issue of training with excessively long contexts. However, since handling input sequences with variable-length sub-sequences (called varlen in flash_attn) natively could make attention computation more efficient, we are also interested in training with varlen and context parallelism. May I ask whether there are any plans to support context parallelism with variable sequence lengths?

Happy to know that CP is helpful for you :)

We still do not have a concrete plan to support variable sequence lengths. I think you mean the thd format, right? That format does not have a sequence dimension, which makes it difficult to do sequence partitioning. Also, if the sequence length varies across inputs, it is very hard to partition sequences with load balancing.

Will consider this anyway, but there is no expected ETA.
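For readers unfamiliar with the layouts, a small illustration (mine, not Transformer Engine code) of why bshd shards cleanly along the sequence dimension while thd does not:

import torch

# bshd layout: [batch, seq, heads, head_dim]; every sample is padded to the same
# sequence length, so a CP rank can take an equal slice of the seq dimension.
bshd = torch.zeros(2, 8, 4, 16)
rank0_shard = bshd[:, :4]  # first half of every sequence

# thd layout: tokens of all samples packed into one dimension [tokens, heads,
# head_dim], with per-sample boundaries given by cu_seqlens (lengths 5 and 3 here).
thd = torch.zeros(8, 4, 16)
cu_seqlens = torch.tensor([0, 5, 8])
# Cutting thd in half (thd[:4]) would give a rank only part of sample 0 and none of
# sample 1, so both the per-sample split and the per-rank compute become unbalanced.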

@cyanguwa
Collaborator

cyanguwa commented Jan 9, 2024

/te-ci pytorch

Collaborator

@cyanguwa cyanguwa left a comment

Tentative CI passed: 79121842. LGTM.

@Infi-zc

Infi-zc commented Jan 10, 2024

> Hi @xrennvidia! Thanks for the contributions to context parallelism. I have been using context parallelism, and it has effectively addressed the issue of training with excessively long contexts. However, since handling input sequences with variable-length sub-sequences (called varlen in flash_attn) natively could make attention computation more efficient, we are also interested in training with varlen and context parallelism. May I ask whether there are any plans to support context parallelism with variable sequence lengths?
>
> Happy to know that CP is helpful for you :)
>
> We still do not have a concrete plan to support variable sequence lengths. I think you mean the thd format, right? That format does not have a sequence dimension, which makes it difficult to do sequence partitioning. Also, if the sequence length varies across inputs, it is very hard to partition sequences with load balancing.
>
> Will consider this anyway, but there is no expected ETA.

Got it, thanks for the reply!

@Infi-zc

Infi-zc commented Jan 10, 2024

> Hi @xrennvidia! Thanks for the contributions to context parallelism. I have been using context parallelism, and it has effectively addressed the issue of training with excessively long contexts. However, since handling input sequences with variable-length sub-sequences (called varlen in flash_attn) natively could make attention computation more efficient, we are also interested in training with varlen and context parallelism. May I ask whether there are any plans to support context parallelism with variable sequence lengths?
>
> Hi @cyanguwa, I was wondering whether there are any plans to support variable-length sequences in context parallelism within Transformer Engine. I'm interested to know whether it's on the roadmap. Thanks!
>
> I think there is some plan to add variable-length support on the cuDNN side, but I don't think there is a definitive timeline at the moment.

Got it, thanks for the reply!

@cyanguwa cyanguwa merged commit 94f54d7 into NVIDIA:main Jan 10, 2024
21 checks passed
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i%2] = q.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()

Could you please explain why only half of the KV sequence from the neighbor rank is used when i <= rank? Is it correct that we are missing the other half of kv_inputs?
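For what it's worth, a small standalone illustration of the reasoning, under two assumptions that are not stated in this excerpt: each rank r holds global chunks r and 2*cp_size-1-r after the load-balancing reorder, and at ring step i the locally held KV originated from rank (r - i) mod cp_size. Under those assumptions, when 0 < i <= rank the second KV chunk lies entirely after every local query token, so the causal mask would zero out its contribution anyway and dropping it loses nothing:

# Example values; any 0 < i <= rank behaves the same way.
cp_size, rank, i = 4, 2, 1
src = (rank - i) % cp_size                    # rank whose KV arrived at this step
q_chunks = [rank, 2 * cp_size - 1 - rank]     # local query chunks: [2, 5]
kv_chunks = [src, 2 * cp_size - 1 - src]      # received KV chunks:  [1, 6]

# Chunk-level causal rule: a query chunk attends a strictly earlier KV chunk with
# no extra mask, and cannot attend a strictly later KV chunk at all.
assert all(kv_chunks[0] < qc for qc in q_chunks)  # first KV half: keep, no mask
assert all(kv_chunks[1] > qc for qc in q_chunks)  # second KV half: fully masked, drop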

@xrennvidia xrennvidia deleted the xren/cp_with_fused_attn branch January 10, 2024 20:03

if causal:
# [b, s, np, hn] -> [b, 2, s//2, np, hn]
q, k, v = [x.view(x.shape[0], 2, x.shape[1]//2, *x.shape[2:]) for x in [q, k, v]]

Hi @xrennvidia, are you trying to load-balance causal attention att(Q, K_i, V_i)? If so, I have done this for both fwd and bwd with a Triton kernel, and there is no need for users to handle it outside of the kernels:

https://github.com/yiakwy-xpu-ml-framework-team/triton/blob/add_support_flash_attention_v3/python/triton/ops/flash_attention.py
