[JAX][Common] Support GQA #578
/te-ci
0f641e6 to 000548c
/te-ci jax
Please fix CI. Looks good to me. Thanks!
transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu
000548c to 5372c5c
/te-ci
26fa707 to 13b4cc0
/te-ci
b538b4a to 13b4cc0
/te-ci
Signed-off-by: Reese Wang <rewang@nvidia.com>
acc308a to bbe7066
/te-ci
Signed-off-by: Reese Wang <rewang@nvidia.com>
/te-ci
@cyanguwa @denera @mingxu1067, all unit tests passed. Could you help review again? Thanks
LGTM
LGTM
* Support num_gqa_groups arguments
* Add GQA support on the JAX bridge code
* Fix the kv stride of the arbitrary backend
* Complete rewrite fused attention tests and add GQA coverage
* Support unfused GQA
* Calculate seqlen before the primitive for the better perf
* Add GQA layer tests
* Apply code style checks for te_jax
* Apply code style checks for tests
* Add num_gqa_groups doc
* Refine the qkv_type
* Correct the variable naming
* Handle Max512 CAUSAL
* Add WAR for the latest jax image

Signed-off-by: Reese Wang <rewang@nvidia.com>
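For readers unfamiliar with the num_gqa_groups argument named in the commits above, here is a minimal Python sketch of how query heads are conventionally mapped onto shared key/value heads in grouped-query attention. The helper name kv_head_for_query_head and the contiguous-block grouping are assumptions made for illustration; they are not code from this PR.

```python
def kv_head_for_query_head(q_head: int, num_heads: int, num_gqa_groups: int) -> int:
    """Return the index of the K/V head shared by query head `q_head`.

    Hypothetical helper: assumes consecutive query heads share one K/V head,
    which is the usual GQA convention but is not confirmed by this PR.
    """
    if num_heads % num_gqa_groups != 0:
        raise ValueError("num_heads must be divisible by num_gqa_groups")
    if not 0 <= q_head < num_heads:
        raise ValueError("q_head out of range")
    heads_per_group = num_heads // num_gqa_groups
    return q_head // heads_per_group


# 8 query heads sharing 2 K/V heads: four query heads per group.
gqa_mapping = [kv_head_for_query_head(h, 8, 2) for h in range(8)]
# num_gqa_groups == 1: every query head reads the same K/V head (multi-query attention).
mqa_mapping = [kv_head_for_query_head(h, 8, 1) for h in range(8)]
# num_gqa_groups == num_heads: one K/V head per query head (plain multi-head attention).
mha_mapping = [kv_head_for_query_head(h, 8, 8) for h in range(8)]
```

The two edge cases show why a single argument is enough to cover multi-head, multi-query, and grouped-query attention.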
Support GQA (num_gqa_groups) for both the fused attention and unfused attention implementations.
Fix the kv_stride of the flash attention
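To illustrate why the key/value stride depends on num_gqa_groups rather than on the number of query heads, the sketch below computes element strides for contiguous row-major tensors. The [batch, seqlen, heads, head_dim] layout, the example sizes, and the helper contiguous_strides are assumptions made for illustration; the actual layout handling in fused_attn_f16_arbitrary_seqlen.cu may differ.

```python
def contiguous_strides(shape):
    """Element strides of a contiguous row-major tensor with the given shape."""
    strides = []
    running = 1
    # Walk from the innermost dimension outwards, accumulating the extent.
    for dim in reversed(shape):
        strides.append(running)
        running *= dim
    return tuple(reversed(strides))


# Assumed example sizes, not taken from the PR.
batch, seqlen, num_heads, num_gqa_groups, head_dim = 2, 128, 16, 4, 64

# Queries keep all num_heads heads.
q_strides = contiguous_strides((batch, seqlen, num_heads, head_dim))
# Keys and values only have num_gqa_groups heads, so their batch and
# sequence strides are smaller than the query strides.
kv_strides = contiguous_strides((batch, seqlen, num_gqa_groups, head_dim))
```

Under these assumptions, reusing the query strides for K/V would step over num_heads * head_dim elements per token instead of num_gqa_groups * head_dim, reading past the end of the smaller K/V buffer.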