
Add flex attention backend #203

Open
Leymore wants to merge 4 commits into main
Conversation

@Leymore commented Mar 3, 2025

This PR implements NATTEN using FlexAttention.

The provided interfaces are flex_na1d, flex_na2d, and flex_na3d, which can replace na1d, na2d, and na3d, but support only the kernel_size, dilation, and is_causal arguments.

Since flex_attention requires kernel compilation, the first run may take longer than usual.

Usage

import torch
from natten.flex import flex_na2d

# NATTEN layout: (batch, height, width, heads, head_dim).
batch_size, image_height, image_width, num_head, head_dim = 1, 64, 64, 8, 64
query = torch.randn(batch_size, image_height, image_width, num_head, head_dim, device='cuda')
key = torch.randn(batch_size, image_height, image_width, num_head, head_dim, device='cuda')
value = torch.randn(batch_size, image_height, image_width, num_head, head_dim, device='cuda')

# Output has the same shape as the query.
output = flex_na2d(query, key, value, kernel_size=11, dilation=1, is_causal=False)

Bug

  • When head_dim=64, num_tokens<=32, the dtype is torch.float16 or torch.bfloat16, and torch.compile(flex_attention) is used, flex_natten produces NaN gradients during backward propagation, even though the forward results remain correct (see the reproduction sketch below).
    • This is a bug in the flex_attention kernel, but the exact trigger conditions have not been fully identified yet. The issue will be reported to the PyTorch team.
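
For reference, a minimal reproduction sketch of the reported case (the exact shapes and the sum-reduction used for backward are assumptions; only the conditions head_dim=64, num_tokens<=32, half precision, and compiled flex_attention come from the report):

import torch
from torch.nn.attention.flex_attention import flex_attention

# Hypothetical reproduction: head_dim=64, num_tokens<=32, half precision,
# compiled flex_attention. Shapes are (batch, heads, seq_len, head_dim).
compiled_flex_attention = torch.compile(flex_attention)

batch, heads, seq_len, head_dim = 1, 8, 32, 64
query, key, value = (
    torch.randn(batch, heads, seq_len, head_dim, device='cuda', dtype=torch.float16, requires_grad=True)
    for _ in range(3)
)

output = compiled_flex_attention(query, key, value)  # forward output is reported to be correct
output.sum().backward()
print(torch.isnan(query.grad).any())  # reportedly True under these conditions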

TODO

  • Optimize the dilation implementation (a hedged sketch of one possible mask-based formulation follows below).
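
As a reference for that item, here is a hedged sketch of how a dilated 1D neighborhood can be expressed as a flex_attention mask_mod. The edge handling is simplified (NATTEN shifts the window near sequence boundaries so every query keeps exactly kernel_size neighbors), and this is an illustration, not the implementation in this PR:

import torch
from torch.nn.attention.flex_attention import create_block_mask

def na1d_mask_mod(kernel_size, dilation, is_causal):
    # A query attends to keys in the same dilation group whose distance is
    # within half the (dilated) kernel. Boundary window shifting is omitted.
    half = (kernel_size // 2) * dilation

    def mask_mod(b, h, q_idx, kv_idx):
        same_group = (q_idx % dilation) == (kv_idx % dilation)
        in_window = (q_idx - kv_idx).abs() <= half
        keep = same_group & in_window
        if is_causal:
            keep = keep & (kv_idx <= q_idx)
        return keep

    return mask_mod

# Example: block mask for 4096 tokens, kernel_size=11, dilation=2.
block_mask = create_block_mask(
    na1d_mask_mod(kernel_size=11, dilation=2, is_causal=False),
    B=None, H=None, Q_LEN=4096, KV_LEN=4096, device='cuda',
)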

@alihassanijr (Member) commented Mar 3, 2025

Left a few comments, mostly nits. Never mind; I had to fix a few things in the unit tests, and applied those changes myself.

Leftover items:

  • Changelog and documentation
  • Verify whether we still need Python's native call cache; torch.compile definitely has its own, so the Python-side one may just end up being redundant.

@Leymore (Author) commented Mar 4, 2025

Updated. It looks like the wrapper around the compiled flex_attention can be removed, while the one around the flex mask cannot: the create_block_mask function is not simply compiling something, so torch.compile's cache does not cover it.
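
To make the caching point concrete, a hedged sketch (not this PR's actual code; the function name and cache key are assumptions): torch.compile already caches the compiled attention kernel, but create_block_mask builds a fresh BlockMask on every call, so a Python-side cache keyed on the mask parameters is still useful.

import functools
import torch
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

# The compiled kernel is cached by torch.compile itself; no extra wrapper needed.
compiled_flex_attention = torch.compile(flex_attention)

@functools.lru_cache(maxsize=None)
def cached_na1d_block_mask(seq_len, kernel_size, is_causal, device='cuda'):
    # Cache the BlockMask by its defining parameters instead of rebuilding it
    # on every forward pass. (Simplified 1D neighborhood rule, no dilation.)
    half = kernel_size // 2

    def mask_mod(b, h, q_idx, kv_idx):
        keep = (q_idx - kv_idx).abs() <= half
        if is_causal:
            keep = keep & (kv_idx <= q_idx)
        return keep

    return create_block_mask(mask_mod, B=None, H=None, Q_LEN=seq_len, KV_LEN=seq_len, device=device)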
