Add flex attention backend #203

Open · wants to merge 4 commits into main
Changes from 2 commits
3 changes: 2 additions & 1 deletion src/natten/__init__.py
@@ -114,6 +114,7 @@
"disable_gemm_na",
"enable_tiled_na",
"disable_tiled_na",
"use_flex_attention",
]

__version__ = "0.17.4"
__version__ = "0.17.5.dev0"
16 changes: 15 additions & 1 deletion src/natten/context.py
@@ -46,13 +46,15 @@ class NattenContext:
is_deterministic_mode_enabled: bool = False
is_fused_na_enabled: bool = False
is_kv_parallelism_enabled: bool = False
use_flex_attention: bool = False

training_memory_preference: MemoryUsagePreference = MemoryUsagePreference.Default

@staticmethod
def reset():
NattenContext.is_deterministic_mode_enabled = False
NattenContext.is_fused_na_enabled = False
NattenContext.use_flex_attention = False
NattenContext.is_kv_parallelism_enabled = False
NattenContext.training_memory_preference = MemoryUsagePreference.Default

@@ -133,9 +135,12 @@ def is_kv_parallelism_in_fused_na_enabled() -> bool:
return NattenContext.is_kv_parallelism_enabled


def use_fused_na(mode: bool = True, kv_parallel: bool = True):
def use_fused_na(
mode: bool = True, kv_parallel: bool = True, use_flex_attention: bool = False
):
if not mode:
NattenContext.is_fused_na_enabled = False
NattenContext.use_flex_attention = False
use_kv_parallelism_in_fused_na(False)
return

@@ -147,12 +152,21 @@ def use_fused_na(mode: bool = True, kv_parallel: bool = True):
)
use_kv_parallelism_in_fused_na(kv_parallel)
NattenContext.is_fused_na_enabled = True
NattenContext.use_flex_attention = use_flex_attention


def is_fused_na_enabled() -> bool:
return NattenContext.is_fused_na_enabled


def should_use_flex_attention() -> bool:
return NattenContext.use_flex_attention


def use_flex_attention() -> None:
use_fused_na(mode=True, use_flex_attention=True)


use_fna = use_fused_na
is_fna_enabled = is_fused_na_enabled

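For reference, a minimal sketch of how the new toggle is meant to be driven from the package level (assuming natten re-exports use_fused_na and the new use_flex_attention, as the __all__ change above suggests):

import natten

# Route fused neighborhood attention through the Flex Attention backend.
natten.use_fused_na(True, use_flex_attention=True)

# Equivalent shortcut added in this PR; it calls use_fused_na with the flag set.
natten.use_flex_attention()

# Disabling fused NA also clears the flex flag (see use_fused_na above).
natten.use_fused_na(False)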
202 changes: 202 additions & 0 deletions src/natten/flex.py
@@ -0,0 +1,202 @@
import functools
from typing import Optional, Tuple

import torch
from torch import BoolTensor, IntTensor, Tensor
from torch.nn.attention.flex_attention import create_block_mask, flex_attention

from .types import (
CausalArg1DTypeOrDed,
CausalArg2DTypeOrDed,
CausalArg3DTypeOrDed,
Dimension1DTypeOrDed,
Dimension2DTypeOrDed,
Dimension3DTypeOrDed,
)
from .utils import check_all_args


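# Compile flex_attention once and cache the compiled callable, so repeated calls
# do not pay the torch.compile cost again.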
@functools.lru_cache(maxsize=1)
def get_flex_attention_compiled():
return torch.compile(flex_attention, dynamic=False)


@functools.lru_cache(maxsize=None)
def get_block_mask(
num_dimension: int,
image_shape: Tuple[int, ...],
kernel_size: Tuple[int, ...],
dilation: Tuple[int, ...],
is_causal: Tuple[bool, ...],
):

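# flex_attention operates on a flattened token sequence, so flat indices are
# mapped back to 1D/2D/3D spatial coordinates (row-major) before the neighborhood check.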
def get_location_1d(idx: IntTensor) -> Tuple[IntTensor]:
return (idx,)

def get_location_2d(idx: IntTensor) -> Tuple[IntTensor, IntTensor]:
return (idx // image_shape[1], idx % image_shape[1]) # type: ignore

def get_location_3d(idx: IntTensor) -> Tuple[IntTensor, IntTensor, IntTensor]:
return (
idx // image_shape[2] // image_shape[1], # type: ignore
(idx // image_shape[2]) % image_shape[1], # type: ignore
idx % image_shape[2], # type: ignore
)

get_location = {
1: get_location_1d,
2: get_location_2d,
3: get_location_3d,
}[num_dimension]

def natten_mask_mod(
b: IntTensor, h: IntTensor, q_idx: IntTensor, kv_idx: IntTensor
) -> BoolTensor:
q_idx = get_location(q_idx) # type: ignore
kv_idx = get_location(kv_idx) # type: ignore

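# Build one boolean mask per spatial axis and AND them together below;
# dilation restricts attention to tokens in the same dilation group along each axis.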
masks = []
for i in range(num_dimension):
dilate_kernel = kernel_size[i] * dilation[i]
if is_causal[i]:
mask = (
(q_idx[i] - kv_idx[i] >= 0)
& (q_idx[i] - kv_idx[i] < dilate_kernel)
& ((q_idx[i] % dilation[i]) == (kv_idx[i] % dilation[i]))
)
else:
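# Non-causal axes: center the window on the query, clamping at the borders so
# every query still attends to a full kernel_size window (NATTEN's neighborhood definition).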
kernel_center_x = q_idx[i].clamp(
(dilate_kernel - 1) // 2,
(image_shape[i] - 1) - (dilate_kernel - 1) // 2,
)
mask = ((kernel_center_x - kv_idx[i]).abs() <= dilate_kernel // 2) & (
(q_idx[i] % dilation[i]) == (kv_idx[i] % dilation[i])
)

masks.append(mask)

return functools.reduce(lambda x, y: x & y, masks) # type: ignore

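# The mask depends only on token positions, so a single (B=1, H=1) block mask over the
# flattened sequence is built once (and cached via lru_cache above) and broadcast across batch and heads.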
seq_length = functools.reduce(lambda x, y: x * y, image_shape)
block_mask = create_block_mask(natten_mask_mod, 1, 1, seq_length, seq_length) # type: ignore
return block_mask


def flex_na1d(
query: Tensor,
key: Tensor,
value: Tensor,
kernel_size: Dimension1DTypeOrDed,
dilation: Dimension1DTypeOrDed = 1,
is_causal: Optional[CausalArg1DTypeOrDed] = False,
) -> torch.Tensor:
"""
Args:
query: (batch_size, seq_length, num_head, head_dim)
key: (batch_size, seq_length, num_head, head_dim)
value: (batch_size, seq_length, num_head, head_dim)
kernel_size: Union[int, Tuple[int]]
dilation: Union[int, Tuple[int]]
is_causal: Union[bool, Tuple[bool]]
"""

kernel_size_, dilation_, is_causal_ = check_all_args(
1, kernel_size, dilation, is_causal
)

batch_size, seq_length, num_head, head_dim = query.shape
image_shape = (seq_length,)

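# flex_attention expects (batch, heads, seq_length, head_dim); NATTEN's layout is
# (batch, seq_length, heads, head_dim), hence the transposes in and out.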
_query = query.transpose(1, 2)
_key = key.transpose(1, 2)
_value = value.transpose(1, 2)

block_mask = get_block_mask(1, image_shape, kernel_size_, dilation_, is_causal_)
flex_attention_compiled = get_flex_attention_compiled()
out = flex_attention_compiled(_query, _key, _value, block_mask=block_mask)

out = out.transpose(1, 2)

return out


def flex_na2d(
query: Tensor,
key: Tensor,
value: Tensor,
kernel_size: Dimension2DTypeOrDed,
dilation: Dimension2DTypeOrDed = 1,
is_causal: Optional[CausalArg2DTypeOrDed] = False,
) -> torch.Tensor:
"""
Args:
query: (batch_size, image_height, image_width, num_head, head_dim)
key: (batch_size, image_height, image_width, num_head, head_dim)
value: (batch_size, image_height, image_width, num_head, head_dim)
kernel_size: Union[int, Tuple[int, int]]
dilation: Union[int, Tuple[int, int]]
is_causal: Union[bool, Tuple[bool, bool]]
"""

kernel_size_, dilation_, is_causal_ = check_all_args(
2, kernel_size, dilation, is_causal
)

batch_size, image_height, image_width, num_head, head_dim = query.shape
seq_length = image_height * image_width
image_shape = (image_height, image_width)

_query = query.view(batch_size, seq_length, num_head, head_dim).transpose(1, 2)
_key = key.view(batch_size, seq_length, num_head, head_dim).transpose(1, 2)
_value = value.view(batch_size, seq_length, num_head, head_dim).transpose(1, 2)

block_mask = get_block_mask(2, image_shape, kernel_size_, dilation_, is_causal_)
flex_attention_compiled = get_flex_attention_compiled()
out = flex_attention_compiled(_query, _key, _value, block_mask=block_mask)

out = out.transpose(1, 2).view(
batch_size, image_height, image_width, num_head, head_dim
)

return out


def flex_na3d(
query: Tensor,
key: Tensor,
value: Tensor,
kernel_size: Dimension3DTypeOrDed,
dilation: Dimension3DTypeOrDed = 1,
is_causal: Optional[CausalArg3DTypeOrDed] = False,
) -> torch.Tensor:
"""
Args:
query: (batch_size, image_depth, image_height, image_width, num_head, head_dim)
key: (batch_size, image_depth, image_height, image_width, num_head, head_dim)
value: (batch_size, image_depth, image_height, image_width, num_head, head_dim)
kernel_size: Union[int, Tuple[int, int, int]]
dilation: Union[int, Tuple[int, int, int]]
is_causal: Union[bool, Tuple[bool, bool, bool]]
"""

kernel_size_, dilation_, is_causal_ = check_all_args(
3, kernel_size, dilation, is_causal
)

batch_size, image_depth, image_height, image_width, num_head, head_dim = query.shape
seq_length = image_depth * image_height * image_width
image_shape = (image_depth, image_height, image_width)

_query = query.view(batch_size, seq_length, num_head, head_dim).transpose(1, 2)
_key = key.view(batch_size, seq_length, num_head, head_dim).transpose(1, 2)
_value = value.view(batch_size, seq_length, num_head, head_dim).transpose(1, 2)

block_mask = get_block_mask(3, image_shape, kernel_size_, dilation_, is_causal_)
flex_attention_compiled = get_flex_attention_compiled()
out = flex_attention_compiled(_query, _key, _value, block_mask=block_mask)

out = out.transpose(1, 2).view(
batch_size, image_depth, image_height, image_width, num_head, head_dim
)

return out
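As a sanity-check sketch for the new module, flex_na1d can be called directly on dummy tensors (shapes, device, and dtype below are illustrative assumptions; this needs a PyTorch build that ships torch.nn.attention.flex_attention):

import torch
from natten.flex import flex_na1d

# batch=2, seq_length=64, heads=4, head_dim=32 -- illustrative values only.
q = torch.randn(2, 64, 4, 32, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# 1D neighborhood attention: window of 7 tokens, no dilation, non-causal.
out = flex_na1d(q, k, v, kernel_size=7, dilation=1, is_causal=False)
assert out.shape == q.shape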
71 changes: 71 additions & 0 deletions src/natten/functional.py
@@ -51,6 +51,8 @@
)

from .autotuner import autotune_fna
from .context import should_use_flex_attention
from .flex import flex_na1d, flex_na2d, flex_na3d
from .nested import (
na1d_av_nested,
na1d_qk_nested,
@@ -1721,6 +1723,29 @@ def na1d(
"Fused neighborhood attention does not support nested tensors yet."
)

if should_use_flex_attention():
if scale is not None:
raise NotImplementedError(
"Custom attention scale is not supported in the Flex Attention backend."
)
if rpb is not None:
raise NotImplementedError(
"RPB is not supported in the Flex Attention backend."
)
if additional_keys is not None or additional_values is not None:
raise NotImplementedError(
"Additional keys/values is not supported in the Flex Attention backend."
)

return flex_na1d(
query,
key,
value,
kernel_size,
dilation,
is_causal,
)

tiling_config_forward, tiling_config_backward = autotune_fna(
1, query, kernel_size, dilation, is_causal
)
@@ -1777,6 +1802,29 @@ def na2d(
"Fused neighborhood attention does not support nested tensors yet."
)

if should_use_flex_attention():
if scale is not None:
raise NotImplementedError(
"Custom attention scale is not supported in the Flex Attention backend."
)
if rpb is not None:
raise NotImplementedError(
"RPB is not supported in the Flex Attention backend."
)
if additional_keys is not None or additional_values is not None:
raise NotImplementedError(
"Additional keys/values is not supported in the Flex Attention backend."
)

return flex_na2d(
query,
key,
value,
kernel_size,
dilation,
is_causal,
)

tiling_config_forward, tiling_config_backward = autotune_fna(
2, query, kernel_size, dilation, is_causal
)
@@ -1833,6 +1881,29 @@ def na3d(
"Fused neighborhood attention does not support nested tensors yet."
)

if should_use_flex_attention():
if scale is not None:
raise NotImplementedError(
"Custom attention scale is not supported in the Flex Attention backend."
)
if rpb is not None:
raise NotImplementedError(
"RPB is not supported in the Flex Attention backend."
)
if additional_keys is not None or additional_values is not None:
raise NotImplementedError(
"Additional keys/values is not supported in the Flex Attention backend."
)

return flex_na3d(
query,
key,
value,
kernel_size,
dilation,
is_causal,
)

tiling_config_forward, tiling_config_backward = autotune_fna(
3, query, kernel_size, dilation, is_causal
)
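Finally, a sketch of the dispatch path added above, with the global flag enabled (shapes are illustrative; no custom scale, rpb, or additional keys/values, since those raise NotImplementedError on this backend):

import torch
import natten
from natten.functional import na2d

natten.use_flex_attention()  # na1d/na2d/na3d now route to flex_na1d/2d/3d

# batch=1, 32x32 feature map, 4 heads, head_dim=32 -- illustrative values only.
q = torch.randn(1, 32, 32, 4, 32, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

out = na2d(q, k, v, kernel_size=(7, 7), dilation=(1, 1), is_causal=False)
assert out.shape == q.shape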