-
Notifications
You must be signed in to change notification settings - Fork 34
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add flex attention backend #203
Open
Leymore
wants to merge
4
commits into
SHI-Labs:main
Choose a base branch
from
Leymore:flex
base: main
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Conversation
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
alihassanijr
reviewed
Mar 3, 2025
alihassanijr
reviewed
Mar 3, 2025
alihassanijr
reviewed
Mar 3, 2025
alihassanijr
reviewed
Mar 3, 2025
alihassanijr
reviewed
Mar 3, 2025
alihassanijr
reviewed
Mar 3, 2025
alihassanijr
reviewed
Mar 3, 2025
leftover items:
|
Updated. It looks like the wrapper for compiled flex_attention can be removed, while the one for the flex mask cannot. The |
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This PR implements NATTEN using FlexAttention.
The provided interfaces include
flex_na1d
,flex_na2d
, andflex_na3d
, which can replacena1d
,na2d
, andna3d
, but withkernel_size
,dilation
andis_casual
support only.Since
flex_attention
requires kernel compilation, the first run may take longer than usual.Usage
Bug
head_dim=64
,num_tokens<=32
, (torch.float16
ortorch.bfloat16
) andtorch.compile(flex_attention)
is used,flex_natten
produces NaN gradients during backward propagation, even though the forward results remain correct.flex_attention
kernel, but the exact trigger conditions are not fully identified. This issue will be reported to the PyTorch team.TODO