Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add flex attention backend #203

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

## [Main branch]

## [0.17.5.dev0] - 2025-03-04
* Integrate flex attention backend.

## [0.17.4] - 2025-01-28
* Support for additional KV tokens in FNA (requires xFormers)
* Adds experimental support for additional KV tokens (attend to local neighborhood, and some
Expand Down
11 changes: 7 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -55,12 +55,15 @@ uninstall:
install:
@echo "Installing NATTEN from source"
NATTEN_CUDA_ARCH="${CUDA_ARCH}" \
NATTEN_N_WORKERS="${WORKERS}" \
NATTEN_WITH_CUDA="${WITH_CUDA}" \
NATTEN_VERBOSE="${VERBOSE}" \
pip install -v -e . 2>&1 | tee install.out
NATTEN_N_WORKERS="${WORKERS}" \
NATTEN_WITH_CUDA="${WITH_CUDA}" \
NATTEN_VERBOSE="${VERBOSE}" \
pip install -v -e . 2>&1 | tee install.out

test:
NATTEN_LOG_LEVEL="CRITICAL" \
PYTORCH_NO_CUDA_MEMORY_CACHING=1 \
CUBLAS_WORKSPACE_CONFIG=":4096:8" \
pytest -v -x ./tests

style:
Expand Down
18 changes: 18 additions & 0 deletions docs/frontend.md
Original file line number Diff line number Diff line change
Expand Up @@ -165,6 +165,24 @@ Future versions may offer more fine-grained control over this.

For more information, refer to [KV parallelism](fna/kv-parallelism.md).

#### Using FlexAttention Backend

NATTEN also supports FlexAttention Backend, which can be enabled as follows:

```python
from natten import use_flex_attention

use_flex_attention(True)
# Enable FlexAttention Backend

use_flex_attention(False)
# Disable FlexAttention Backend (default)
```

FlexAttention could be potentially faster than FNA on modern GPU architectures, especially for higher dimensionals (2-D or 3-D).

However, FlexAttention backend is still experimental, and may contain certain bugs due to its kernel implementation. Bug reports related to this backend in general are strongly appreciated.

### Operations
Operations are one level below our modules, and are intended to give you full control over the module-level
details, and only use the underlying neighborhood attention operators directly.
Expand Down
3 changes: 2 additions & 1 deletion src/natten/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@
"disable_gemm_na",
"enable_tiled_na",
"disable_tiled_na",
"use_flex_attention",
]

__version__ = "0.17.4"
__version__ = "0.17.5.dev0"
16 changes: 15 additions & 1 deletion src/natten/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -46,13 +46,15 @@ class NattenContext:
is_deterministic_mode_enabled: bool = False
is_fused_na_enabled: bool = False
is_kv_parallelism_enabled: bool = False
use_flex_attention: bool = False

training_memory_preference: MemoryUsagePreference = MemoryUsagePreference.Default

@staticmethod
def reset():
NattenContext.is_deterministic_mode_enabled = False
NattenContext.is_fused_na_enabled = False
NattenContext.use_flex_attention = False
NattenContext.is_kv_parallelism_enabled = False
NattenContext.training_memory_preference = MemoryUsagePreference.Default

Expand Down Expand Up @@ -133,9 +135,12 @@ def is_kv_parallelism_in_fused_na_enabled() -> bool:
return NattenContext.is_kv_parallelism_enabled


def use_fused_na(mode: bool = True, kv_parallel: bool = True):
def use_fused_na(
mode: bool = True, kv_parallel: bool = True, use_flex_attention: bool = False
):
if not mode:
NattenContext.is_fused_na_enabled = False
NattenContext.use_flex_attention = False
use_kv_parallelism_in_fused_na(False)
return

Expand All @@ -147,12 +152,21 @@ def use_fused_na(mode: bool = True, kv_parallel: bool = True):
)
use_kv_parallelism_in_fused_na(kv_parallel)
NattenContext.is_fused_na_enabled = True
NattenContext.use_flex_attention = use_flex_attention


def is_fused_na_enabled() -> bool:
return NattenContext.is_fused_na_enabled


def should_use_flex_attention() -> bool:
return NattenContext.use_flex_attention


def use_flex_attention() -> bool:
return use_fused_na(mode=True, use_flex_attention=True)


use_fna = use_fused_na
is_fna_enabled = is_fused_na_enabled

Expand Down
256 changes: 256 additions & 0 deletions src/natten/flex.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,256 @@
#################################################################################################
# Copyright (c) 2022-2024 Ali Hassani.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
#################################################################################################

import functools
import math
from typing import Optional, Tuple

import torch
from torch import BoolTensor, IntTensor, Tensor
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

from .types import (
CausalArg1DTypeOrDed,
CausalArg2DTypeOrDed,
CausalArg3DTypeOrDed,
CausalArgType,
Dimension1DTypeOrDed,
Dimension2DTypeOrDed,
Dimension3DTypeOrDed,
DimensionType,
)
from .utils import check_all_args


def get_flex_attention_compiled():
return torch.compile(flex_attention, dynamic=False)


@functools.lru_cache(maxsize=None)
def get_na_flex_mask(
na_dim: int,
input_size: DimensionType,
kernel_size: DimensionType,
dilation: DimensionType,
is_causal: CausalArgType,
):

def index_to_coord_1d(idx: IntTensor) -> Tuple[IntTensor]:
assert len(input_size) == 1
return (idx,)

def index_to_coord_2d(idx: IntTensor) -> Tuple[IntTensor, IntTensor]:
assert len(input_size) == 2
return (idx // input_size[1], idx % input_size[1]) # type: ignore

def index_to_coord_3d(idx: IntTensor) -> Tuple[IntTensor, IntTensor, IntTensor]:
assert len(input_size) == 3
return (
idx // input_size[2] // input_size[1], # type: ignore
(idx // input_size[2]) % input_size[1], # type: ignore
idx % input_size[2], # type: ignore
)

index_to_coord = {
1: index_to_coord_1d,
2: index_to_coord_2d,
3: index_to_coord_3d,
}[na_dim]

def na_mask_mod(
b: IntTensor, h: IntTensor, q_idx: IntTensor, kv_idx: IntTensor
) -> BoolTensor:
q_coord = index_to_coord(q_idx)
kv_coord = index_to_coord(kv_idx)

masks = []
for i in range(na_dim):
kernel_times_dilation = kernel_size[i] * dilation[i]
if is_causal[i]:
mask = (
(q_coord[i] - kv_coord[i] >= 0)
& (q_coord[i] - kv_coord[i] < kernel_times_dilation)
& ((q_coord[i] % dilation[i]) == (kv_coord[i] % dilation[i]))
)
else:
kernel_center_x = q_coord[i].clamp(
(kernel_times_dilation - 1) // 2,
(input_size[i] - 1) - (kernel_times_dilation - 1) // 2,
)
mask = (
(kernel_center_x - kv_coord[i]).abs() <= kernel_times_dilation // 2
) & ((q_coord[i] % dilation[i]) == (kv_coord[i] % dilation[i]))

masks.append(mask)

return functools.reduce(lambda x, y: x & y, masks) # type: ignore

seq_length = math.prod(input_size)
return create_block_mask(
na_mask_mod, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length, _compile=True # type: ignore
)


def flex_na1d(
query: Tensor,
key: Tensor,
value: Tensor,
kernel_size: Dimension1DTypeOrDed,
dilation: Dimension1DTypeOrDed = 1,
is_causal: Optional[CausalArg1DTypeOrDed] = False,
) -> torch.Tensor:

kernel_size_, dilation_, is_causal_ = check_all_args(
1, kernel_size, dilation, is_causal
)

if query.dim() != 4 or key.dim() != 4 or value.dim() != 4:
raise ValueError(
"flex_na1d expects query, key, and value to be 4-dimensional tensors, "
f"got {query.shape=}, {key.shape=}, {value.shape=}."
)

if query.shape != key.shape or query.shape != value.shape:
raise ValueError(
"flex_na1d expects query, key, and value to have the same shape, "
f"got {query.shape=}, {key.shape=}, {value.shape=}."
)

if query.dtype != key.dtype or query.dtype != value.dtype:
raise ValueError(
"flex_na1d expects query, key, and value to have the same dtype, "
f"got {query.dtype=}, {key.dtype=}, {value.dtype=}."
)

batch_size, seqlen, num_heads, head_dim = query.shape
input_size = (seqlen,)

query_ = query.transpose(1, 2)
key_ = key.transpose(1, 2)
value_ = value.transpose(1, 2)

na_mask = get_na_flex_mask(1, input_size, kernel_size_, dilation_, is_causal_)
flex_attention_compiled = get_flex_attention_compiled()
out_ = flex_attention_compiled(query_, key_, value_, block_mask=na_mask)

out = out_.transpose(1, 2)

return out


def flex_na2d(
query: Tensor,
key: Tensor,
value: Tensor,
kernel_size: Dimension2DTypeOrDed,
dilation: Dimension2DTypeOrDed = 1,
is_causal: Optional[CausalArg2DTypeOrDed] = False,
) -> torch.Tensor:

kernel_size_, dilation_, is_causal_ = check_all_args(
2, kernel_size, dilation, is_causal
)

if query.dim() != 5 or key.dim() != 5 or value.dim() != 5:
raise ValueError(
"flex_na2d expects query, key, and value to be 5-dimensional tensors, "
f"got {query.shape=}, {key.shape=}, {value.shape=}."
)

if query.shape != key.shape or query.shape != value.shape:
raise ValueError(
"flex_na2d expects query, key, and value to have the same shape, "
f"got {query.shape=}, {key.shape=}, {value.shape=}."
)

if query.dtype != key.dtype or query.dtype != value.dtype:
raise ValueError(
"flex_na2d expects query, key, and value to have the same dtype, "
f"got {query.dtype=}, {key.dtype=}, {value.dtype=}."
)

batch_size, seqlen_1, seqlen_2, num_heads, head_dim = query.shape
seq_length = seqlen_1 * seqlen_2
input_size = (seqlen_1, seqlen_2)

query_ = query.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2)
key_ = key.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2)
value_ = value.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2)

na_mask = get_na_flex_mask(2, input_size, kernel_size_, dilation_, is_causal_)
flex_attention_compiled = get_flex_attention_compiled()
out_ = flex_attention_compiled(query_, key_, value_, block_mask=na_mask)

out = out_.transpose(1, 2).view(batch_size, seqlen_1, seqlen_2, num_heads, head_dim)

return out


def flex_na3d(
query: Tensor,
key: Tensor,
value: Tensor,
kernel_size: Dimension3DTypeOrDed,
dilation: Dimension3DTypeOrDed = 1,
is_causal: Optional[CausalArg3DTypeOrDed] = False,
) -> torch.Tensor:

kernel_size_, dilation_, is_causal_ = check_all_args(
3, kernel_size, dilation, is_causal
)

if query.dim() != 6 or key.dim() != 6 or value.dim() != 6:
raise ValueError(
"flex_na3d expects query, key, and value to be 6-dimensional tensors, "
f"got {query.shape=}, {key.shape=}, {value.shape=}."
)

if query.shape != key.shape or query.shape != value.shape:
raise ValueError(
"flex_na3d expects query, key, and value to have the same shape, "
f"got {query.shape=}, {key.shape=}, {value.shape=}."
)

if query.dtype != key.dtype or query.dtype != value.dtype:
raise ValueError(
"flex_na3d expects query, key, and value to have the same dtype, "
f"got {query.dtype=}, {key.dtype=}, {value.dtype=}."
)

batch_size, seqlen_0, seqlen_1, seqlen_2, num_heads, head_dim = query.shape
seq_length = seqlen_0 * seqlen_1 * seqlen_2
input_size = (seqlen_0, seqlen_1, seqlen_2)

query_ = query.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2)
key_ = key.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2)
value_ = value.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2)

na_mask = get_na_flex_mask(3, input_size, kernel_size_, dilation_, is_causal_)
flex_attention_compiled = get_flex_attention_compiled()
out_ = flex_attention_compiled(query_, key_, value_, block_mask=na_mask)

out = out_.transpose(1, 2).view(
batch_size, seqlen_0, seqlen_1, seqlen_2, num_heads, head_dim
)

return out
Loading