[PyTorch] upgrade context parallelism implementations #572
Conversation
xrennvidia commented on Dec 19, 2023
- port the cuDNN Flash Attn API to the CP implementation
- support both unidirectional (causal) and bidirectional attention
- make the CP implementation work with window_size values of [-1, -1] and [-1, 0] (see the sketch below)
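For context, in the flash-attn window_size convention, (-1, -1) means full bidirectional attention and (-1, 0) means causal attention. A minimal sketch of that mapping (the helper name is illustrative, not code from this PR):

```python
def mask_type_from_window_size(window_size):
    # Illustrative only: the two window_size values this PR supports map to
    # bidirectional and causal attention respectively.
    if window_size == (-1, -1):
        return "no_mask"   # bidirectional: every token can attend to every token
    if window_size == (-1, 0):
        return "causal"    # unidirectional: attend only to current and past tokens
    raise ValueError(f"sketch only covers (-1, -1) and (-1, 0), got {window_size}")
```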
/te-ci pytorch
    q_inputs[i%2], kv_inputs[i%2][0], kv_inputs[i%2][1], TE_DType[q.dtype],
    tex.NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen,
    attn_scale=softmax_scale, dropout=dropout_p,
    qkv_layout="bshd_bshd_bshd", attn_mask_type="causal",
)
else:
FlashAttention and FusedAttention might not be mutually exclusive; they could both be True or both be False. Maybe it's better to pass another flag, use_flash_attention, in here?
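A rough sketch of that suggestion (the dispatcher name and step-function parameters are hypothetical, not the actual signatures in this PR):

```python
# Sketch only: the caller passes the backend choice explicitly, instead of the
# CP code assuming that exactly one of FlashAttention / FusedAttention is enabled.
def attn_forward_step(q, kv, softmax_scale, dropout_p,
                      flash_step, fused_step, use_flash_attention=True):
    step_fn = flash_step if use_flash_attention else fused_step
    return step_fn(q, kv[0], kv[1], softmax_scale, dropout_p)
```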
/te-ci pytorch
Hi @xrennvidia! Thanks for the contributions to context parallelism. I have been using context parallelism, and it has effectively addressed the issue of training with excessively long contexts. However, since input sequences with variable-length sub-sequences (called varlen in flash_attn) could further improve the efficiency of the attention computation, we are also interested in training with varlen and context parallelism. May I ask whether there are any plans to support context parallelism under variable-length conditions?
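For reference, "varlen" here refers to flash-attn's packed variable-length interface, which takes cumulative sequence lengths instead of a padded batch. A minimal usage sketch, assuming flash-attn 2.x (the shapes and lengths are made up):

```python
import torch
from flash_attn import flash_attn_varlen_func

# Three sequences of lengths 5, 3 and 8 packed into one tensor of 16 tokens.
cu_seqlens = torch.tensor([0, 5, 8, 16], dtype=torch.int32, device="cuda")
q = torch.randn(16, 16, 64, dtype=torch.float16, device="cuda")  # [total_tokens, heads, head_dim]
k, v = torch.randn_like(q), torch.randn_like(q)

out = flash_attn_varlen_func(
    q, k, v,
    cu_seqlens_q=cu_seqlens, cu_seqlens_k=cu_seqlens,
    max_seqlen_q=8, max_seqlen_k=8,
    causal=True,
)
```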
Hi @cyanguwa, I was wondering whether there are any plans to support variable-length sequences in the context-parallel implementation within Transformer Engine? I'm interested to know if it's on the roadmap. Thanks!
I think there's some plan to add variable-length support on the cuDNN side, but I don't think there's a definitive timeline at the moment.
Happy to know that CP is helpful for you :) We still do not have a concrete plan for supporting variable sequence lengths. We will consider it anyway, but there is no expected ETA.
/te-ci pytorch
Tentative CI passed: 79121842. LGTM.
Got it, thanks for the reply!
Got it, thanks for the reply!
# [b, 2, sq//2, np, hn] -> [b, sq, np, hn]
q_inputs[i%2] = q.view(q.shape[0], -1, *q.shape[-2:])
# [2, b, 2, sk//2, np, hn] -> [2, b, sk//2, np, hn]
kv_inputs[i%2] = kv_inputs[i%2][:, :, 0, ...].contiguous()
Could you please explain why only half of the KV sequence received from the neighboring rank is used when i <= rank? Is it possible that the other half of kv_inputs is being missed?
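For context, a sketch of the chunk bookkeeping that seems to be behind this (assuming the load-balanced layout where rank r holds sequence chunks r and 2*cp_size-1-r, and that at ring step i a rank holds the KV originally owned by rank (r-i) mod cp_size; none of this is code from the PR):

```python
def visible_kv_chunks(cp_size, rank, step):
    # Which of the two received KV chunks can be attended by the local queries
    # under a causal mask, given the assumed load-balanced chunk assignment.
    q_chunks = [rank, 2 * cp_size - 1 - rank]
    src = (rank - step) % cp_size                    # rank the KV came from
    kv_chunks = [src, 2 * cp_size - 1 - src]
    # A KV chunk is needed if at least one local query chunk does not lie
    # strictly before it in the original sequence order.
    return [c for c in kv_chunks if any(qc >= c for qc in q_chunks)]

# Example with cp_size=4, rank=2, step=1: the source rank 1 holds chunks [1, 6],
# local q chunks are [2, 5], so chunk 6 is entirely in the causal future and only
# the first half of the received KV is needed -- hence kv_inputs[...][:, :, 0, ...].
print(visible_kv_chunks(cp_size=4, rank=2, step=1))  # [1]
```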
if causal:
    # [b, s, np, hn] -> [b, 2, s//2, np, hn]
    q, k, v = [x.view(x.shape[0], 2, x.shape[1]//2, *x.shape[2:]) for x in [q, k, v]]
Hi @xrennvidia, are you trying to load-balance causal attention att(Q, K_i, V_i)? If so, I have done this for both fwd and bwd with a Triton kernel, and there is no need for users to take care of it outside of the kernels:
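For readers following along: the load-balancing trick being discussed is to split the sequence into 2*cp_size chunks and give rank r the pair (r, 2*cp_size-1-r), which roughly equalizes causal-attention work across ranks. A minimal sketch of that sharding (illustrative only, not code from this PR or the Triton kernel mentioned above):

```python
import torch

def zigzag_shard(x, cp_size, rank, seq_dim=1):
    # Split the sequence into 2*cp_size chunks and keep chunks r and
    # 2*cp_size-1-r, so every rank gets a similar amount of causal work.
    chunks = x.chunk(2 * cp_size, dim=seq_dim)
    return torch.cat([chunks[rank], chunks[2 * cp_size - 1 - rank]], dim=seq_dim)

# Example: 16 tokens, cp_size=4; rank 0 keeps chunks 0 and 7, i.e. tokens
# [0, 1] and [14, 15].
x = torch.arange(16).view(1, 16, 1)
print(zigzag_shard(x, cp_size=4, rank=0).flatten().tolist())  # [0, 1, 14, 15]
```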