From af4072b32fb178515eaad7766f06abcaf1dc319a Mon Sep 17 00:00:00 2001 From: Fengzhe Zhou Date: Mon, 3 Mar 2025 00:33:05 -0500 Subject: [PATCH 1/4] add flex attention backend --- src/natten/flex.py | 215 +++++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 215 insertions(+) create mode 100644 src/natten/flex.py diff --git a/src/natten/flex.py b/src/natten/flex.py new file mode 100644 index 0000000..62c2642 --- /dev/null +++ b/src/natten/flex.py @@ -0,0 +1,215 @@ +import functools +from typing import Any, Dict, Optional, Tuple + +import torch +from torch import BoolTensor, IntTensor, Tensor +from torch.nn.attention.flex_attention import create_block_mask, flex_attention + +from .types import ( + CausalArg1DTypeOrDed, + CausalArg2DTypeOrDed, + CausalArg3DTypeOrDed, + Dimension1DTypeOrDed, + Dimension2DTypeOrDed, + Dimension3DTypeOrDed, +) +from .utils import check_all_args + + +@functools.lru_cache(maxsize=1) +def get_flex_attention_compiled(): + return torch.compile(flex_attention, dynamic=False) + + +@functools.lru_cache(maxsize=None) +def get_block_mask( + num_dimension: int, + image_shape: Tuple[int], + kernel_size: Tuple[int], + dilation: Tuple[int], + is_causal: Tuple[bool], +): + + def get_location_1d(idx: IntTensor) -> IntTensor: + return (idx,) + + def get_location_2d(idx: IntTensor) -> Tuple[IntTensor, IntTensor]: + return (idx // image_shape[1], idx % image_shape[1]) + + def get_location_3d(idx: IntTensor) -> Tuple[IntTensor, IntTensor, IntTensor]: + return (idx // image_shape[2] // image_shape[1], (idx // image_shape[2]) % image_shape[1], idx % image_shape[2]) + + get_location = { + 1: get_location_1d, + 2: get_location_2d, + 3: get_location_3d, + }[num_dimension] + + def natten_mask_mod(b: IntTensor, h: IntTensor, q_idx: IntTensor, kv_idx: IntTensor) -> BoolTensor: + q_idx = get_location(q_idx) + kv_idx = get_location(kv_idx) + + masks = [] + for i in range(num_dimension): + dilate_kernel = kernel_size[i] * dilation[i] + if is_causal[i]: + mask = ( + (q_idx[i] - kv_idx[i] >= 0) + & (q_idx[i] - kv_idx[i] < dilate_kernel) + & ((q_idx[i] % dilation[i]) == (kv_idx[i] % dilation[i])) + ) + else: + kernel_center_x = q_idx[i].clamp( + (dilate_kernel - 1) // 2, (image_shape[i] - 1) - (dilate_kernel - 1) // 2 + ) + mask = ((kernel_center_x - kv_idx[i]).abs() <= dilate_kernel // 2) & ( + (q_idx[i] % dilation[i]) == (kv_idx[i] % dilation[i]) + ) + + masks.append(mask) + + return functools.reduce(lambda x, y: x & y, masks) + + seq_length = functools.reduce(lambda x, y: x * y, image_shape) + block_mask = create_block_mask(natten_mask_mod, 1, 1, seq_length, seq_length) + return block_mask + + +def flex_na1d( + query: Tensor, + key: Tensor, + value: Tensor, + kernel_size: Dimension1DTypeOrDed, + dilation: Dimension1DTypeOrDed = 1, + is_causal: Optional[CausalArg1DTypeOrDed] = False, + rpb: Optional[Tensor] = None, + scale: Optional[float] = None, + additional_keys: Optional[Tensor] = None, + additional_values: Optional[Tensor] = None, + xformers_kwargs: Optional[Dict] = None, +) -> torch.Tensor: + """ + Args: + query: (batch_size, seq_length, num_head, head_dim) + key: (batch_size, seq_length, num_head, head_dim) + value: (batch_size, seq_length, num_head, head_dim) + kernel_size: Union[int, Tuple[int]] + dilation: Union[int, Tuple[int]] + is_causal: Union[bool, Tuple[bool]] + """ + + kernel_size, dilation, is_causal = check_all_args(1, kernel_size, dilation, is_causal) + assert rpb is None, "rpb is not supported" + assert scale is None, "scale is not supported" + assert 
additional_keys is None, "additional_keys is not supported" + assert additional_values is None, "additional_values is not supported" + assert xformers_kwargs is None, "xformers_kwargs is not supported" + + batch_size, seq_length, num_head, head_dim = query.shape + image_shape = (seq_length,) + + _query = query.transpose(1, 2) + _key = key.transpose(1, 2) + _value = value.transpose(1, 2) + + block_mask = get_block_mask(1, image_shape, kernel_size, dilation, is_causal) + flex_attention_compiled = get_flex_attention_compiled() + out = flex_attention_compiled(_query, _key, _value, block_mask=block_mask) + + out = out.transpose(1, 2) + + return out + + +def flex_na2d( + query: Tensor, + key: Tensor, + value: Tensor, + kernel_size: Dimension2DTypeOrDed, + dilation: Dimension2DTypeOrDed = 1, + is_causal: Optional[CausalArg2DTypeOrDed] = False, + rpb: Optional[Tensor] = None, + scale: Optional[float] = None, + additional_keys: Optional[Tensor] = None, + additional_values: Optional[Tensor] = None, + xformers_kwargs: Optional[Dict] = None, +) -> torch.Tensor: + """ + Args: + query: (batch_size, image_height, image_width, num_head, head_dim) + key: (batch_size, image_height, image_width, num_head, head_dim) + value: (batch_size, image_height, image_width, num_head, head_dim) + kernel_size: Union[int, Tuple[int, int]] + dilation: Union[int, Tuple[int, int]] + is_causal: Union[bool, Tuple[bool, bool]] + """ + + kernel_size, dilation, is_causal = check_all_args(2, kernel_size, dilation, is_causal) + assert rpb is None, "rpb is not supported" + assert scale is None, "scale is not supported" + assert additional_keys is None, "additional_keys is not supported" + assert additional_values is None, "additional_values is not supported" + assert xformers_kwargs is None, "xformers_kwargs is not supported" + + batch_size, image_height, image_width, num_head, head_dim = query.shape + seq_length = image_height * image_width + image_shape = (image_height, image_width) + + _query = query.view(batch_size, seq_length, num_head, head_dim).transpose(1, 2) + _key = key.view(batch_size, seq_length, num_head, head_dim).transpose(1, 2) + _value = value.view(batch_size, seq_length, num_head, head_dim).transpose(1, 2) + + block_mask = get_block_mask(2, image_shape, kernel_size, dilation, is_causal) + flex_attention_compiled = get_flex_attention_compiled() + out = flex_attention_compiled(_query, _key, _value, block_mask=block_mask) + + out = out.transpose(1, 2).view(batch_size, image_height, image_width, num_head, head_dim) + + return out + + +def flex_na3d( + query: Tensor, + key: Tensor, + value: Tensor, + kernel_size: Dimension3DTypeOrDed, + dilation: Dimension3DTypeOrDed = 1, + is_causal: Optional[CausalArg3DTypeOrDed] = False, + rpb: Optional[Tensor] = None, + scale: Optional[float] = None, + additional_keys: Optional[Tensor] = None, + additional_values: Optional[Tensor] = None, + xformers_kwargs: Optional[Dict] = None, +) -> torch.Tensor: + """ + Args: + query: (batch_size, image_height, image_width, num_head, head_dim) + key: (batch_size, image_height, image_width, num_head, head_dim) + value: (batch_size, image_height, image_width, num_head, head_dim) + kernel_size: Union[int, Tuple[int, int]] + dilation: Union[int, Tuple[int, int]] + is_causal: Union[bool, Tuple[bool, bool]] + """ + + kernel_size, dilation, is_causal = check_all_args(3, kernel_size, dilation, is_causal) + assert rpb is None, "rpb is not supported" + assert scale is None, "scale is not supported" + assert additional_keys is None, "additional_keys is not 
supported" + assert additional_values is None, "additional_values is not supported" + assert xformers_kwargs is None, "xformers_kwargs is not supported" + + batch_size, image_depth, image_height, image_width, num_head, head_dim = query.shape + seq_length = image_depth * image_height * image_width + image_shape = (image_depth, image_height, image_width) + + _query = query.view(batch_size, seq_length, num_head, head_dim).transpose(1, 2) + _key = key.view(batch_size, seq_length, num_head, head_dim).transpose(1, 2) + _value = value.view(batch_size, seq_length, num_head, head_dim).transpose(1, 2) + + block_mask = get_block_mask(3, image_shape, kernel_size, dilation, is_causal) + flex_attention_compiled = get_flex_attention_compiled() + out = flex_attention_compiled(_query, _key, _value, block_mask=block_mask) + + out = out.transpose(1, 2).view(batch_size, image_depth, image_height, image_width, num_head, head_dim) + + return out From 85fdacaf17d28be0261c1313331c471399b015be Mon Sep 17 00:00:00 2001 From: Ali Hassani Date: Mon, 3 Mar 2025 15:51:45 -0500 Subject: [PATCH 2/4] Integrate flex attention backend --- src/natten/__init__.py | 3 +- src/natten/context.py | 16 +++- src/natten/flex.py | 83 ++++++++---------- src/natten/functional.py | 71 +++++++++++++++ tests/test_fna1d.py | 158 +++++++++++++++++++++++++++++++++ tests/test_fna2d.py | 177 +++++++++++++++++++++++++++++++++++++ tests/test_fna3d.py | 179 ++++++++++++++++++++++++++++++++++++++ tools/profile_1d.py | 6 +- tools/profile_2d.py | 6 +- tools/profile_3d.py | 6 +- tools/utils/formatting.py | 8 +- 11 files changed, 656 insertions(+), 57 deletions(-) diff --git a/src/natten/__init__.py b/src/natten/__init__.py index f5e6e42..bdb65b8 100644 --- a/src/natten/__init__.py +++ b/src/natten/__init__.py @@ -114,6 +114,7 @@ "disable_gemm_na", "enable_tiled_na", "disable_tiled_na", + "use_flex_attention", ] -__version__ = "0.17.4" +__version__ = "0.17.5.dev0" diff --git a/src/natten/context.py b/src/natten/context.py index fc27a35..90593d2 100644 --- a/src/natten/context.py +++ b/src/natten/context.py @@ -46,6 +46,7 @@ class NattenContext: is_deterministic_mode_enabled: bool = False is_fused_na_enabled: bool = False is_kv_parallelism_enabled: bool = False + use_flex_attention: bool = False training_memory_preference: MemoryUsagePreference = MemoryUsagePreference.Default @@ -53,6 +54,7 @@ class NattenContext: def reset(): NattenContext.is_deterministic_mode_enabled = False NattenContext.is_fused_na_enabled = False + NattenContext.use_flex_attention = False NattenContext.is_kv_parallelism_enabled = False NattenContext.training_memory_preference = MemoryUsagePreference.Default @@ -133,9 +135,12 @@ def is_kv_parallelism_in_fused_na_enabled() -> bool: return NattenContext.is_kv_parallelism_enabled -def use_fused_na(mode: bool = True, kv_parallel: bool = True): +def use_fused_na( + mode: bool = True, kv_parallel: bool = True, use_flex_attention: bool = False +): if not mode: NattenContext.is_fused_na_enabled = False + NattenContext.use_flex_attention = False use_kv_parallelism_in_fused_na(False) return @@ -147,12 +152,21 @@ def use_fused_na(mode: bool = True, kv_parallel: bool = True): ) use_kv_parallelism_in_fused_na(kv_parallel) NattenContext.is_fused_na_enabled = True + NattenContext.use_flex_attention = use_flex_attention def is_fused_na_enabled() -> bool: return NattenContext.is_fused_na_enabled +def should_use_flex_attention() -> bool: + return NattenContext.use_flex_attention + + +def use_flex_attention() -> bool: + return use_fused_na(mode=True, 
use_flex_attention=True) + + use_fna = use_fused_na is_fna_enabled = is_fused_na_enabled diff --git a/src/natten/flex.py b/src/natten/flex.py index 62c2642..e87141c 100644 --- a/src/natten/flex.py +++ b/src/natten/flex.py @@ -1,5 +1,5 @@ import functools -from typing import Any, Dict, Optional, Tuple +from typing import Optional, Tuple import torch from torch import BoolTensor, IntTensor, Tensor @@ -30,14 +30,18 @@ def get_block_mask( is_causal: Tuple[bool], ): - def get_location_1d(idx: IntTensor) -> IntTensor: + def get_location_1d(idx: IntTensor) -> Tuple[IntTensor]: return (idx,) def get_location_2d(idx: IntTensor) -> Tuple[IntTensor, IntTensor]: - return (idx // image_shape[1], idx % image_shape[1]) + return (idx // image_shape[1], idx % image_shape[1]) # type: ignore def get_location_3d(idx: IntTensor) -> Tuple[IntTensor, IntTensor, IntTensor]: - return (idx // image_shape[2] // image_shape[1], (idx // image_shape[2]) % image_shape[1], idx % image_shape[2]) + return ( + idx // image_shape[2] // image_shape[1], # type: ignore + (idx // image_shape[2]) % image_shape[1], # type: ignore + idx % image_shape[2], # type: ignore + ) get_location = { 1: get_location_1d, @@ -45,9 +49,11 @@ def get_location_3d(idx: IntTensor) -> Tuple[IntTensor, IntTensor, IntTensor]: 3: get_location_3d, }[num_dimension] - def natten_mask_mod(b: IntTensor, h: IntTensor, q_idx: IntTensor, kv_idx: IntTensor) -> BoolTensor: - q_idx = get_location(q_idx) - kv_idx = get_location(kv_idx) + def natten_mask_mod( + b: IntTensor, h: IntTensor, q_idx: IntTensor, kv_idx: IntTensor + ) -> BoolTensor: + q_idx = get_location(q_idx) # type: ignore + kv_idx = get_location(kv_idx) # type: ignore masks = [] for i in range(num_dimension): @@ -60,7 +66,8 @@ def natten_mask_mod(b: IntTensor, h: IntTensor, q_idx: IntTensor, kv_idx: IntTen ) else: kernel_center_x = q_idx[i].clamp( - (dilate_kernel - 1) // 2, (image_shape[i] - 1) - (dilate_kernel - 1) // 2 + (dilate_kernel - 1) // 2, + (image_shape[i] - 1) - (dilate_kernel - 1) // 2, ) mask = ((kernel_center_x - kv_idx[i]).abs() <= dilate_kernel // 2) & ( (q_idx[i] % dilation[i]) == (kv_idx[i] % dilation[i]) @@ -68,10 +75,10 @@ def natten_mask_mod(b: IntTensor, h: IntTensor, q_idx: IntTensor, kv_idx: IntTen masks.append(mask) - return functools.reduce(lambda x, y: x & y, masks) + return functools.reduce(lambda x, y: x & y, masks) # type: ignore seq_length = functools.reduce(lambda x, y: x * y, image_shape) - block_mask = create_block_mask(natten_mask_mod, 1, 1, seq_length, seq_length) + block_mask = create_block_mask(natten_mask_mod, 1, 1, seq_length, seq_length) # type: ignore return block_mask @@ -82,11 +89,6 @@ def flex_na1d( kernel_size: Dimension1DTypeOrDed, dilation: Dimension1DTypeOrDed = 1, is_causal: Optional[CausalArg1DTypeOrDed] = False, - rpb: Optional[Tensor] = None, - scale: Optional[float] = None, - additional_keys: Optional[Tensor] = None, - additional_values: Optional[Tensor] = None, - xformers_kwargs: Optional[Dict] = None, ) -> torch.Tensor: """ Args: @@ -98,12 +100,9 @@ def flex_na1d( is_causal: Union[bool, Tuple[bool]] """ - kernel_size, dilation, is_causal = check_all_args(1, kernel_size, dilation, is_causal) - assert rpb is None, "rpb is not supported" - assert scale is None, "scale is not supported" - assert additional_keys is None, "additional_keys is not supported" - assert additional_values is None, "additional_values is not supported" - assert xformers_kwargs is None, "xformers_kwargs is not supported" + kernel_size_, dilation_, is_causal_ = check_all_args( 
+ 1, kernel_size, dilation, is_causal + ) batch_size, seq_length, num_head, head_dim = query.shape image_shape = (seq_length,) @@ -112,7 +111,7 @@ def flex_na1d( _key = key.transpose(1, 2) _value = value.transpose(1, 2) - block_mask = get_block_mask(1, image_shape, kernel_size, dilation, is_causal) + block_mask = get_block_mask(1, image_shape, kernel_size_, dilation_, is_causal_) flex_attention_compiled = get_flex_attention_compiled() out = flex_attention_compiled(_query, _key, _value, block_mask=block_mask) @@ -128,11 +127,6 @@ def flex_na2d( kernel_size: Dimension2DTypeOrDed, dilation: Dimension2DTypeOrDed = 1, is_causal: Optional[CausalArg2DTypeOrDed] = False, - rpb: Optional[Tensor] = None, - scale: Optional[float] = None, - additional_keys: Optional[Tensor] = None, - additional_values: Optional[Tensor] = None, - xformers_kwargs: Optional[Dict] = None, ) -> torch.Tensor: """ Args: @@ -144,12 +138,9 @@ def flex_na2d( is_causal: Union[bool, Tuple[bool, bool]] """ - kernel_size, dilation, is_causal = check_all_args(2, kernel_size, dilation, is_causal) - assert rpb is None, "rpb is not supported" - assert scale is None, "scale is not supported" - assert additional_keys is None, "additional_keys is not supported" - assert additional_values is None, "additional_values is not supported" - assert xformers_kwargs is None, "xformers_kwargs is not supported" + kernel_size_, dilation_, is_causal_ = check_all_args( + 2, kernel_size, dilation, is_causal + ) batch_size, image_height, image_width, num_head, head_dim = query.shape seq_length = image_height * image_width @@ -159,11 +150,13 @@ def flex_na2d( _key = key.view(batch_size, seq_length, num_head, head_dim).transpose(1, 2) _value = value.view(batch_size, seq_length, num_head, head_dim).transpose(1, 2) - block_mask = get_block_mask(2, image_shape, kernel_size, dilation, is_causal) + block_mask = get_block_mask(2, image_shape, kernel_size_, dilation_, is_causal_) flex_attention_compiled = get_flex_attention_compiled() out = flex_attention_compiled(_query, _key, _value, block_mask=block_mask) - out = out.transpose(1, 2).view(batch_size, image_height, image_width, num_head, head_dim) + out = out.transpose(1, 2).view( + batch_size, image_height, image_width, num_head, head_dim + ) return out @@ -175,11 +168,6 @@ def flex_na3d( kernel_size: Dimension3DTypeOrDed, dilation: Dimension3DTypeOrDed = 1, is_causal: Optional[CausalArg3DTypeOrDed] = False, - rpb: Optional[Tensor] = None, - scale: Optional[float] = None, - additional_keys: Optional[Tensor] = None, - additional_values: Optional[Tensor] = None, - xformers_kwargs: Optional[Dict] = None, ) -> torch.Tensor: """ Args: @@ -191,12 +179,9 @@ def flex_na3d( is_causal: Union[bool, Tuple[bool, bool]] """ - kernel_size, dilation, is_causal = check_all_args(3, kernel_size, dilation, is_causal) - assert rpb is None, "rpb is not supported" - assert scale is None, "scale is not supported" - assert additional_keys is None, "additional_keys is not supported" - assert additional_values is None, "additional_values is not supported" - assert xformers_kwargs is None, "xformers_kwargs is not supported" + kernel_size_, dilation_, is_causal_ = check_all_args( + 3, kernel_size, dilation, is_causal + ) batch_size, image_depth, image_height, image_width, num_head, head_dim = query.shape seq_length = image_depth * image_height * image_width @@ -206,10 +191,12 @@ def flex_na3d( _key = key.view(batch_size, seq_length, num_head, head_dim).transpose(1, 2) _value = value.view(batch_size, seq_length, num_head, 
head_dim).transpose(1, 2) - block_mask = get_block_mask(3, image_shape, kernel_size, dilation, is_causal) + block_mask = get_block_mask(3, image_shape, kernel_size_, dilation_, is_causal_) flex_attention_compiled = get_flex_attention_compiled() out = flex_attention_compiled(_query, _key, _value, block_mask=block_mask) - out = out.transpose(1, 2).view(batch_size, image_depth, image_height, image_width, num_head, head_dim) + out = out.transpose(1, 2).view( + batch_size, image_depth, image_height, image_width, num_head, head_dim + ) return out diff --git a/src/natten/functional.py b/src/natten/functional.py index 5f389dc..f45d77e 100644 --- a/src/natten/functional.py +++ b/src/natten/functional.py @@ -51,6 +51,8 @@ ) from .autotuner import autotune_fna +from .context import should_use_flex_attention +from .flex import flex_na1d, flex_na2d, flex_na3d from .nested import ( na1d_av_nested, na1d_qk_nested, @@ -1721,6 +1723,29 @@ def na1d( "Fused neighborhood attention does not support nested tensors yet." ) + if should_use_flex_attention(): + if scale is not None: + raise NotImplementedError( + "Custom attention scale is not supported in the Flex Attention backend." + ) + if rpb is not None: + raise NotImplementedError( + "RPB is not supported in the Flex Attention backend." + ) + if additional_keys is not None or additional_values is not None: + raise NotImplementedError( + "Additional keys/values is not supported in the Flex Attention backend." + ) + + return flex_na1d( + query, + key, + value, + kernel_size, + dilation, + is_causal, + ) + tiling_config_forward, tiling_config_backward = autotune_fna( 1, query, kernel_size, dilation, is_causal ) @@ -1777,6 +1802,29 @@ def na2d( "Fused neighborhood attention does not support nested tensors yet." ) + if should_use_flex_attention(): + if scale is not None: + raise NotImplementedError( + "Custom attention scale is not supported in the Flex Attention backend." + ) + if rpb is not None: + raise NotImplementedError( + "RPB is not supported in the Flex Attention backend." + ) + if additional_keys is not None or additional_values is not None: + raise NotImplementedError( + "Additional keys/values is not supported in the Flex Attention backend." + ) + + return flex_na2d( + query, + key, + value, + kernel_size, + dilation, + is_causal, + ) + tiling_config_forward, tiling_config_backward = autotune_fna( 2, query, kernel_size, dilation, is_causal ) @@ -1833,6 +1881,29 @@ def na3d( "Fused neighborhood attention does not support nested tensors yet." ) + if should_use_flex_attention(): + if scale is not None: + raise NotImplementedError( + "Custom attention scale is not supported in the Flex Attention backend." + ) + if rpb is not None: + raise NotImplementedError( + "RPB is not supported in the Flex Attention backend." + ) + if additional_keys is not None or additional_values is not None: + raise NotImplementedError( + "Additional keys/values is not supported in the Flex Attention backend." 
+ ) + + return flex_na3d( + query, + key, + value, + kernel_size, + dilation, + is_causal, + ) + tiling_config_forward, tiling_config_backward = autotune_fna( 3, query, kernel_size, dilation, is_causal ) diff --git a/tests/test_fna1d.py b/tests/test_fna1d.py index ed33cf5..c393c47 100644 --- a/tests/test_fna1d.py +++ b/tests/test_fna1d.py @@ -33,6 +33,7 @@ use_autotuner, use_kv_parallelism_in_fused_na, ) +from natten.flex import flex_na1d from natten.functional import na1d, na1d_av, na1d_qk from natten.utils import check_all_args from natten.utils.testing import ( @@ -533,6 +534,163 @@ def test_against_sdpa(self): ) +class FlexAttentionFNA1DTest(unittest.TestCase): + def setUp(self): + _reset_everything() + + def tearDown(self): + _reset_everything() + + def _test_against_cutlass_fna( + self, B, H, L, D, kernel_size, dilation, is_causal, eps, dtype + ): + kernel_size, dilation, is_causal = check_args(kernel_size, dilation, is_causal) + with torch.no_grad(): + q, k, v, d_out = ( + torch.randn((B, L, H, D), device="cuda", dtype=dtype), + torch.randn((B, L, H, D), device="cuda", dtype=dtype), + torch.randn((B, L, H, D), device="cuda", dtype=dtype), + torch.randn((B, L, H, D), device="cuda", dtype=dtype) * 0.05, + ) + + q_ref, k_ref, v_ref, d_out_ref = ( + q.clone(), + k.clone(), + v.clone(), + d_out.clone(), + ) + + # Reference + q_ref.requires_grad_(True) + k_ref.requires_grad_(True) + v_ref.requires_grad_(True) + d_out_ref.requires_grad_(True) + out_ref_ = na1d( + q_ref, + k_ref, + v_ref, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + ) + out_ref = out_ref_.data.clone().float() + + dq_ref, dk_ref, dv_ref = None, None, None + out_ref_.backward(d_out_ref) + with torch.no_grad(): + dq_ref, dk_ref, dv_ref = ( + q_ref.grad.clone().float(), + k_ref.grad.clone().float(), + v_ref.grad.clone().float(), + ) + + # Flex + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + d_out.requires_grad_(True) + + out_ = flex_na1d( + q, + k, + v, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + ) + out = out_.data.clone().float() + + dq, dk, dv = None, None, None + out_.backward(d_out) + with torch.no_grad(): + dq, dk, dv = ( + q.grad.clone().float(), + k.grad.clone().float(), + v.grad.clone().float(), + ) + + torch.testing.assert_close(out, out_ref, atol=eps, rtol=0) + torch.testing.assert_close(dq, dq_ref, atol=eps, rtol=0) + torch.testing.assert_close(dk, dk_ref, atol=eps, rtol=0) + torch.testing.assert_close(dv, dv_ref, atol=eps, rtol=0) + + def _test_all_dtypes( + self, + B, + H, + L, + D, + kernel_size, + dilation, + is_causal=None, + ): + self._test_against_cutlass_fna( + B=B, + H=H, + L=L, + D=D, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + dtype=torch.float32, + eps=1e-2, + ) + if HAS_HALF: + self._test_against_cutlass_fna( + B=B, + H=H, + L=L, + D=D, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + dtype=torch.float16, + eps=1e-1, + ) + if HAS_BFLOAT: + self._test_against_cutlass_fna( + B=B, + H=H, + L=L, + D=D, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + dtype=torch.bfloat16, + eps=1e-1, + ) + + @skip_if_cuda_is_not_supported() + @skip_if_fna_is_not_supported() + def test_against_cutlass_fna(self): + problem_sizes = [ + (1, 1, 3, 16, 3, 1), + (1, 1, 16, 32, 3, 1), + (1, 2, 33, 32, 15, 1), + (1, 2, 33, 64, 15, 2), + (4, 3, 256, 64, 255, 1), + (2, 2, 4096, 64, 2047, 1), + (2, 4, 4096, 64, 2047, 2), + (4, 3, 5000, 64, 511, 8), + (4, 3, 5000, 64, 255, 
16), + (1, 12, 512, 64, 255, 1), + # TODO: these will fail on most non-A100/H100 cards due to the 99KB shmem limit + # (4, 24, 512, 128, 99, 1), + # (1, 48, 512, 256, 45, 4), + ] + for B, H, L, D, kernel_size, dilation in problem_sizes: + for is_causal in [False, True]: + self._test_all_dtypes( + B=B, + H=H, + L=L, + D=D, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + ) + + if __name__ == "__main__": torch.manual_seed(42) unittest.main() diff --git a/tests/test_fna2d.py b/tests/test_fna2d.py index 7f25261..d46d31d 100644 --- a/tests/test_fna2d.py +++ b/tests/test_fna2d.py @@ -34,6 +34,7 @@ use_autotuner, use_kv_parallelism_in_fused_na, ) +from natten.flex import flex_na2d from natten.functional import na2d, na2d_av, na2d_qk from natten.utils import check_all_args from natten.utils.testing import ( @@ -550,6 +551,182 @@ def test_against_sdpa(self): ) +class FlexAttentionFNA2DTest(unittest.TestCase): + def setUp(self): + _reset_everything() + + def tearDown(self): + _reset_everything() + + def _test_against_cutlass_fna( + self, B, H, X, Y, D, kernel_size, dilation, is_causal, eps, dtype + ): + kernel_size, dilation, is_causal = check_args(kernel_size, dilation, is_causal) + with torch.no_grad(): + q, k, v, d_out = ( + torch.randn((B, X, Y, H, D), device="cuda", dtype=dtype), + torch.randn((B, X, Y, H, D), device="cuda", dtype=dtype), + torch.randn((B, X, Y, H, D), device="cuda", dtype=dtype), + torch.randn((B, X, Y, H, D), device="cuda", dtype=dtype) * 0.05, + ) + + q_ref, k_ref, v_ref, d_out_ref = ( + q.clone(), + k.clone(), + v.clone(), + d_out.clone(), + ) + + # Reference + q_ref.requires_grad_(True) + k_ref.requires_grad_(True) + v_ref.requires_grad_(True) + d_out_ref.requires_grad_(True) + out_ref_ = na2d( + q_ref, + k_ref, + v_ref, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + ) + out_ref = out_ref_.data.clone().float() + + dq_ref, dk_ref, dv_ref = None, None, None + out_ref_.backward(d_out_ref) + with torch.no_grad(): + dq_ref, dk_ref, dv_ref = ( + q_ref.grad.clone().float(), + k_ref.grad.clone().float(), + v_ref.grad.clone().float(), + ) + + # Flex + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + d_out.requires_grad_(True) + + out_ = flex_na2d( + q, + k, + v, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + ) + out = out_.data.clone().float() + + dq, dk, dv = None, None, None + out_.backward(d_out) + with torch.no_grad(): + dq, dk, dv = ( + q.grad.clone().float(), + k.grad.clone().float(), + v.grad.clone().float(), + ) + + torch.testing.assert_close(out, out_ref, atol=eps, rtol=0) + torch.testing.assert_close(dq, dq_ref, atol=eps, rtol=0) + torch.testing.assert_close(dk, dk_ref, atol=eps, rtol=0) + torch.testing.assert_close(dv, dv_ref, atol=eps, rtol=0) + + def _test_all_dtypes( + self, + B, + H, + X, + Y, + D, + kernel_size, + dilation, + is_causal=None, + ): + self._test_against_cutlass_fna( + B=B, + H=H, + X=X, + Y=Y, + D=D, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + dtype=torch.float32, + eps=1e-2, + ) + if HAS_HALF: + self._test_against_cutlass_fna( + B=B, + H=H, + X=X, + Y=Y, + D=D, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + dtype=torch.float16, + eps=1e-1, + ) + if HAS_BFLOAT: + self._test_against_cutlass_fna( + B=B, + H=H, + X=X, + Y=Y, + D=D, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + dtype=torch.bfloat16, + eps=1e-1, + ) + + @skip_if_cuda_is_not_supported() + 
@skip_if_fna_is_not_supported() + def test_against_cutlass_fna(self): + problem_sizes = [ + (1, 1, 3, 3, 16, 3, 3, 1, 1), + (1, 1, 8, 10, 16, 3, 5, 1, 2), + (1, 2, 15, 20, 32, 5, 15, 3, 1), + (1, 2, 17, 19, 32, 9, 7, 1, 1), + (1, 2, 17, 19, 32, 9, 7, 1, 2), + (4, 3, 32, 32, 32, 31, 31, 1, 1), + (2, 2, 32, 64, 64, 25, 31, 1, 2), + (2, 4, 64, 128, 64, 55, 101, 1, 1), + (2, 4, 64, 128, 64, 21, 29, 3, 4), + # TODO: these will fail on most non-A100/H100 cards due to the 99KB shmem limit + # (4, 3, 56, 56, 128, 7, 7, 2, 4), + # (4, 3, 28, 46, 128, 11, 13, 1, 1), + (4, 3, 56, 56, 64, 7, 7, 2, 4), + (4, 3, 28, 46, 64, 11, 13, 1, 1), + ] + for ( + B, + H, + X, + Y, + D, + kernel_size_h, + kernel_size_w, + dilation_h, + dilation_w, + ) in problem_sizes: + for causal_h, causal_w in product([True, False], [True, False]): + kernel_size = (kernel_size_h, kernel_size_w) + dilation = (dilation_h, dilation_w) + is_causal = (causal_h, causal_w) + self._test_all_dtypes( + B=B, + H=H, + X=X, + Y=Y, + D=D, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + ) + + if __name__ == "__main__": torch.manual_seed(42) unittest.main() diff --git a/tests/test_fna3d.py b/tests/test_fna3d.py index 18fee35..7cef58c 100644 --- a/tests/test_fna3d.py +++ b/tests/test_fna3d.py @@ -34,6 +34,7 @@ use_autotuner, use_kv_parallelism_in_fused_na, ) +from natten.flex import flex_na3d from natten.functional import na3d, na3d_av, na3d_qk from natten.utils import check_all_args from natten.utils.testing import ( @@ -559,6 +560,184 @@ def test_against_sdpa(self): ) +class FlexAttentionFNA3DTest(unittest.TestCase): + def setUp(self): + _reset_everything() + + def tearDown(self): + _reset_everything() + + def _test_against_cutlass_fna( + self, B, H, X, Y, Z, D, kernel_size, dilation, is_causal, eps, dtype + ): + kernel_size, dilation, is_causal = check_args(kernel_size, dilation, is_causal) + with torch.no_grad(): + q, k, v, d_out = ( + torch.randn((B, X, Y, Z, H, D), device="cuda", dtype=dtype), + torch.randn((B, X, Y, Z, H, D), device="cuda", dtype=dtype), + torch.randn((B, X, Y, Z, H, D), device="cuda", dtype=dtype), + torch.randn((B, X, Y, Z, H, D), device="cuda", dtype=dtype) * 0.05, + ) + + q_ref, k_ref, v_ref, d_out_ref = ( + q.clone(), + k.clone(), + v.clone(), + d_out.clone(), + ) + + # Reference + q_ref.requires_grad_(True) + k_ref.requires_grad_(True) + v_ref.requires_grad_(True) + d_out_ref.requires_grad_(True) + out_ref_ = na3d( + q_ref, + k_ref, + v_ref, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + ) + out_ref = out_ref_.data.clone().float() + + dq_ref, dk_ref, dv_ref = None, None, None + out_ref_.backward(d_out_ref) + with torch.no_grad(): + dq_ref, dk_ref, dv_ref = ( + q_ref.grad.clone().float(), + k_ref.grad.clone().float(), + v_ref.grad.clone().float(), + ) + + # Flex + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + d_out.requires_grad_(True) + + out_ = flex_na3d( + q, + k, + v, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + ) + out = out_.data.clone().float() + + dq, dk, dv = None, None, None + out_.backward(d_out) + with torch.no_grad(): + dq, dk, dv = ( + q.grad.clone().float(), + k.grad.clone().float(), + v.grad.clone().float(), + ) + + torch.testing.assert_close(out, out_ref, atol=eps, rtol=0) + torch.testing.assert_close(dq, dq_ref, atol=eps, rtol=0) + torch.testing.assert_close(dk, dk_ref, atol=eps, rtol=0) + torch.testing.assert_close(dv, dv_ref, atol=eps, rtol=0) + + def _test_all_dtypes( + self, + B, + H, + X, 
+ Y, + Z, + D, + kernel_size, + dilation, + is_causal=None, + ): + self._test_against_cutlass_fna( + B=B, + H=H, + X=X, + Y=Y, + Z=Z, + D=D, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + dtype=torch.float32, + eps=1e-2, + ) + if HAS_HALF: + self._test_against_cutlass_fna( + B=B, + H=H, + X=X, + Y=Y, + Z=Z, + D=D, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + dtype=torch.float16, + eps=1e-1, + ) + if HAS_BFLOAT: + self._test_against_cutlass_fna( + B=B, + H=H, + X=X, + Y=Y, + Z=Z, + D=D, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + dtype=torch.bfloat16, + eps=1e-1, + ) + + @skip_if_cuda_is_not_supported() + @skip_if_fna_is_not_supported() + def test_against_cutlass_fna(self): + problem_sizes = [ + (1, 1, 3, 3, 3, 16, 3, 3, 3, 1, 1, 1), + (1, 2, 6, 8, 12, 16, 5, 7, 11, 1, 1, 1), + (1, 4, 6, 8, 12, 32, 3, 3, 3, 2, 2, 4), + (2, 2, 6, 8, 12, 32, 3, 3, 3, 1, 1, 1), + (1, 12, 32, 8, 8, 64, 7, 5, 5, 2, 1, 1), + (4, 8, 32, 10, 10, 64, 7, 3, 3, 1, 2, 3), + ] + for ( + B, + H, + X, + Y, + Z, + D, + kernel_size_d, + kernel_size_h, + kernel_size_w, + dilation_d, + dilation_h, + dilation_w, + ) in problem_sizes: + for causal_d, causal_h, causal_w in product( + [True, False], [True, False], [True, False] + ): + kernel_size = (kernel_size_d, kernel_size_h, kernel_size_w) + dilation = (dilation_d, dilation_h, dilation_w) + is_causal = (causal_d, causal_h, causal_w) + self._test_all_dtypes( + B=B, + H=H, + X=X, + Y=Y, + Z=Z, + D=D, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + ) + + if __name__ == "__main__": torch.manual_seed(42) unittest.main() diff --git a/tools/profile_1d.py b/tools/profile_1d.py index 9d7b69e..3ed384e 100644 --- a/tools/profile_1d.py +++ b/tools/profile_1d.py @@ -55,6 +55,7 @@ @click.option("--fmha", is_flag=True) @click.option("--fav2", is_flag=True) @click.option("--backprop", is_flag=True) +@click.option("--flex", is_flag=True) @click.option("--add-kv", default=0) def profile_1d( batch_size: int, @@ -76,6 +77,7 @@ def profile_1d( fav2: bool, backprop: bool, add_kv: int, + flex: bool, ): dtype = torch.float32 @@ -90,7 +92,7 @@ def profile_1d( natten.libnatten.set_gemm_tf32(False) if fuse: - natten.use_fused_na() + natten.use_fused_na(True, use_flex_attention=flex) natten.use_kv_parallelism_in_fused_na() natten.set_memory_usage_preference("unrestricted") @@ -98,6 +100,8 @@ def profile_1d( natten.use_autotuner(False, False, False, False) else: natten.use_autotuner(True, True) + elif flex: + natten.use_fused_na(True, use_flex_attention=flex) func = partial(profile_na_with_torch, fuse=fuse) if fmha: diff --git a/tools/profile_2d.py b/tools/profile_2d.py index 16686f0..2b1f072 100644 --- a/tools/profile_2d.py +++ b/tools/profile_2d.py @@ -57,6 +57,7 @@ @click.option("--fmha", is_flag=True) @click.option("--fav2", is_flag=True) @click.option("--backprop", is_flag=True) +@click.option("--flex", is_flag=True) @click.option("--add-kv", default=0) def profile_2d( batch_size: int, @@ -80,6 +81,7 @@ def profile_2d( fav2: bool, backprop: bool, add_kv: int, + flex: bool, ): dtype = torch.float32 @@ -96,7 +98,7 @@ def profile_2d( natten.libnatten.set_gemm_tf32(False) if fuse: - natten.use_fused_na() + natten.use_fused_na(True, use_flex_attention=flex) natten.use_kv_parallelism_in_fused_na() natten.set_memory_usage_preference("unrestricted") @@ -104,6 +106,8 @@ def profile_2d( natten.use_autotuner(False, False, False, False) else: natten.use_autotuner(True, True) + elif flex: + 
natten.use_fused_na(True, use_flex_attention=flex) func = partial(profile_na_with_torch, fuse=fuse) if fmha: diff --git a/tools/profile_3d.py b/tools/profile_3d.py index 05669d9..74d7efe 100644 --- a/tools/profile_3d.py +++ b/tools/profile_3d.py @@ -55,6 +55,7 @@ @click.option("--fmha", is_flag=True) @click.option("--fav2", is_flag=True) @click.option("--backprop", is_flag=True) +@click.option("--flex", is_flag=True) @click.option("--add-kv", default=0) def profile_3d( batch_size: int, @@ -76,6 +77,7 @@ def profile_3d( fav2: bool, backprop: bool, add_kv: int, + flex: bool, ): dtype = torch.float32 @@ -85,7 +87,7 @@ def profile_3d( dtype = torch.bfloat16 if fuse: - natten.use_fused_na() + natten.use_fused_na(True, use_flex_attention=flex) natten.use_kv_parallelism_in_fused_na() natten.set_memory_usage_preference("unrestricted") @@ -93,6 +95,8 @@ def profile_3d( natten.use_autotuner(False, False, False, False) else: natten.use_autotuner(True, True) + elif flex: + natten.use_fused_na(True, use_flex_attention=flex) func = partial(profile_na_with_torch, fuse=fuse) if fmha: diff --git a/tools/utils/formatting.py b/tools/utils/formatting.py index a79732b..c949995 100644 --- a/tools/utils/formatting.py +++ b/tools/utils/formatting.py @@ -302,14 +302,14 @@ def extract_na_ops( tags[op] = tag else: assert tags[op] == tag - logged_ops[op].append(evt.cuda_time_total) - elif evt.cuda_time_total > 0: + logged_ops[op].append(evt.device_time_total) + elif evt.device_time_total > 0: op_namespace, op_name = custom_op_to_name(evt.key) op_key = CustomOp(op_name, op_namespace) if op_key not in logged_ops: - logged_ops[op_key] = [evt.cuda_time_total] + logged_ops[op_key] = [evt.device_time_total] else: - logged_ops[op_key].append(evt.cuda_time_total) + logged_ops[op_key].append(evt.device_time_total) converted_ops = convert_ops(logged_ops, tags) return None if converted_ops is None else sorted(converted_ops) From e12c710bb7a1768d8e146b9c5b7b394b572abb67 Mon Sep 17 00:00:00 2001 From: Ali Hassani Date: Mon, 3 Mar 2025 19:30:23 -0500 Subject: [PATCH 3/4] Fix unit tests, formatting changes --- Makefile | 11 +- src/natten/flex.py | 237 ++++++++++++++++++++++-------------- tests/test_compute_delta.py | 1 + tests/test_fna1d.py | 3 +- tests/test_fna2d.py | 3 +- tests/test_fna3d.py | 3 +- tests/test_na1d.py | 1 + tests/test_na2d.py | 1 + tests/test_na3d.py | 1 + 9 files changed, 163 insertions(+), 98 deletions(-) diff --git a/Makefile b/Makefile index 7e13381..d087de1 100644 --- a/Makefile +++ b/Makefile @@ -55,12 +55,15 @@ uninstall: install: @echo "Installing NATTEN from source" NATTEN_CUDA_ARCH="${CUDA_ARCH}" \ - NATTEN_N_WORKERS="${WORKERS}" \ - NATTEN_WITH_CUDA="${WITH_CUDA}" \ - NATTEN_VERBOSE="${VERBOSE}" \ - pip install -v -e . 2>&1 | tee install.out + NATTEN_N_WORKERS="${WORKERS}" \ + NATTEN_WITH_CUDA="${WITH_CUDA}" \ + NATTEN_VERBOSE="${VERBOSE}" \ + pip install -v -e . 2>&1 | tee install.out test: + NATTEN_LOG_LEVEL="CRITICAL" \ + PYTORCH_NO_CUDA_MEMORY_CACHING=1 \ + CUBLAS_WORKSPACE_CONFIG=":4096:8" \ pytest -v -x ./tests style: diff --git a/src/natten/flex.py b/src/natten/flex.py index e87141c..0bffe88 100644 --- a/src/natten/flex.py +++ b/src/natten/flex.py @@ -1,4 +1,28 @@ +################################################################################################# +# Copyright (c) 2022-2024 Ali Hassani. 
+# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. +# +################################################################################################# + import functools +import math from typing import Optional, Tuple import torch @@ -9,9 +33,11 @@ CausalArg1DTypeOrDed, CausalArg2DTypeOrDed, CausalArg3DTypeOrDed, + CausalArgType, Dimension1DTypeOrDed, Dimension2DTypeOrDed, Dimension3DTypeOrDed, + DimensionType, ) from .utils import check_all_args @@ -22,64 +48,68 @@ def get_flex_attention_compiled(): @functools.lru_cache(maxsize=None) -def get_block_mask( - num_dimension: int, - image_shape: Tuple[int], - kernel_size: Tuple[int], - dilation: Tuple[int], - is_causal: Tuple[bool], +def get_na_flex_mask( + na_dim: int, + input_size: DimensionType, + kernel_size: DimensionType, + dilation: DimensionType, + is_causal: CausalArgType, ): - def get_location_1d(idx: IntTensor) -> Tuple[IntTensor]: + def index_to_coord_1d(idx: IntTensor) -> Tuple[IntTensor]: + assert len(input_size) == 1 return (idx,) - def get_location_2d(idx: IntTensor) -> Tuple[IntTensor, IntTensor]: - return (idx // image_shape[1], idx % image_shape[1]) # type: ignore + def index_to_coord_2d(idx: IntTensor) -> Tuple[IntTensor, IntTensor]: + assert len(input_size) == 2 + return (idx // input_size[1], idx % input_size[1]) # type: ignore - def get_location_3d(idx: IntTensor) -> Tuple[IntTensor, IntTensor, IntTensor]: + def index_to_coord_3d(idx: IntTensor) -> Tuple[IntTensor, IntTensor, IntTensor]: + assert len(input_size) == 3 return ( - idx // image_shape[2] // image_shape[1], # type: ignore - (idx // image_shape[2]) % image_shape[1], # type: ignore - idx % image_shape[2], # type: ignore + idx // input_size[2] // input_size[1], # type: ignore + (idx // input_size[2]) % input_size[1], # type: ignore + idx % input_size[2], # type: ignore ) - get_location = { - 1: get_location_1d, - 2: get_location_2d, - 3: get_location_3d, - }[num_dimension] + index_to_coord = { + 1: index_to_coord_1d, + 2: index_to_coord_2d, + 3: index_to_coord_3d, + }[na_dim] - def natten_mask_mod( + def na_mask_mod( b: IntTensor, h: IntTensor, q_idx: IntTensor, kv_idx: IntTensor ) -> BoolTensor: - q_idx = get_location(q_idx) # type: ignore - kv_idx = get_location(kv_idx) # type: ignore + q_coord = index_to_coord(q_idx) + kv_coord = index_to_coord(kv_idx) masks = [] - for i in range(num_dimension): - dilate_kernel = kernel_size[i] * dilation[i] + for i in range(na_dim): + kernel_times_dilation = kernel_size[i] * dilation[i] if is_causal[i]: mask = ( - 
(q_idx[i] - kv_idx[i] >= 0) - & (q_idx[i] - kv_idx[i] < dilate_kernel) - & ((q_idx[i] % dilation[i]) == (kv_idx[i] % dilation[i])) + (q_coord[i] - kv_coord[i] >= 0) + & (q_coord[i] - kv_coord[i] < kernel_times_dilation) + & ((q_coord[i] % dilation[i]) == (kv_coord[i] % dilation[i])) ) else: - kernel_center_x = q_idx[i].clamp( - (dilate_kernel - 1) // 2, - (image_shape[i] - 1) - (dilate_kernel - 1) // 2, - ) - mask = ((kernel_center_x - kv_idx[i]).abs() <= dilate_kernel // 2) & ( - (q_idx[i] % dilation[i]) == (kv_idx[i] % dilation[i]) + kernel_center_x = q_coord[i].clamp( + (kernel_times_dilation - 1) // 2, + (input_size[i] - 1) - (kernel_times_dilation - 1) // 2, ) + mask = ( + (kernel_center_x - kv_coord[i]).abs() <= kernel_times_dilation // 2 + ) & ((q_coord[i] % dilation[i]) == (kv_coord[i] % dilation[i])) masks.append(mask) return functools.reduce(lambda x, y: x & y, masks) # type: ignore - seq_length = functools.reduce(lambda x, y: x * y, image_shape) - block_mask = create_block_mask(natten_mask_mod, 1, 1, seq_length, seq_length) # type: ignore - return block_mask + seq_length = math.prod(input_size) + return create_block_mask( + na_mask_mod, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length # type: ignore + ) def flex_na1d( @@ -90,32 +120,41 @@ def flex_na1d( dilation: Dimension1DTypeOrDed = 1, is_causal: Optional[CausalArg1DTypeOrDed] = False, ) -> torch.Tensor: - """ - Args: - query: (batch_size, seq_length, num_head, head_dim) - key: (batch_size, seq_length, num_head, head_dim) - value: (batch_size, seq_length, num_head, head_dim) - kernel_size: Union[int, Tuple[int]] - dilation: Union[int, Tuple[int]] - is_causal: Union[bool, Tuple[bool]] - """ kernel_size_, dilation_, is_causal_ = check_all_args( 1, kernel_size, dilation, is_causal ) - batch_size, seq_length, num_head, head_dim = query.shape - image_shape = (seq_length,) + if query.dim() != 4 or key.dim() != 4 or value.dim() != 4: + raise ValueError( + "flex_na1d expects query, key, and value to be 5-dimensional tensors, " + f"got {query.shape=}, {key.shape=}, {value.shape=}." + ) + + if query.shape != key.shape or query.shape != value.shape: + raise ValueError( + "flex_na1d expects query, key, and value to have the same shape, " + f"got {query.shape=}, {key.shape=}, {value.shape=}." + ) + + if query.dtype != key.dtype or query.dtype != value.dtype: + raise ValueError( + "flex_na1d expects query, key, and value to have the same dtype, " + f"got {query.dtype=}, {key.dtype=}, {value.dtype=}." 
+ ) - _query = query.transpose(1, 2) - _key = key.transpose(1, 2) - _value = value.transpose(1, 2) + batch_size, seqlen, num_heads, head_dim = query.shape + input_size = (seqlen,) - block_mask = get_block_mask(1, image_shape, kernel_size_, dilation_, is_causal_) + query_ = query.transpose(1, 2) + key_ = key.transpose(1, 2) + value_ = value.transpose(1, 2) + + na_mask = get_na_flex_mask(1, input_size, kernel_size_, dilation_, is_causal_) flex_attention_compiled = get_flex_attention_compiled() - out = flex_attention_compiled(_query, _key, _value, block_mask=block_mask) + out_ = flex_attention_compiled(query_, key_, value_, block_mask=na_mask) - out = out.transpose(1, 2) + out = out_.transpose(1, 2) return out @@ -128,35 +167,42 @@ def flex_na2d( dilation: Dimension2DTypeOrDed = 1, is_causal: Optional[CausalArg2DTypeOrDed] = False, ) -> torch.Tensor: - """ - Args: - query: (batch_size, image_height, image_width, num_head, head_dim) - key: (batch_size, image_height, image_width, num_head, head_dim) - value: (batch_size, image_height, image_width, num_head, head_dim) - kernel_size: Union[int, Tuple[int, int]] - dilation: Union[int, Tuple[int, int]] - is_causal: Union[bool, Tuple[bool, bool]] - """ kernel_size_, dilation_, is_causal_ = check_all_args( 2, kernel_size, dilation, is_causal ) - batch_size, image_height, image_width, num_head, head_dim = query.shape - seq_length = image_height * image_width - image_shape = (image_height, image_width) + if query.dim() != 5 or key.dim() != 5 or value.dim() != 5: + raise ValueError( + "flex_na2d expects query, key, and value to be 5-dimensional tensors, " + f"got {query.shape=}, {key.shape=}, {value.shape=}." + ) - _query = query.view(batch_size, seq_length, num_head, head_dim).transpose(1, 2) - _key = key.view(batch_size, seq_length, num_head, head_dim).transpose(1, 2) - _value = value.view(batch_size, seq_length, num_head, head_dim).transpose(1, 2) + if query.shape != key.shape or query.shape != value.shape: + raise ValueError( + "flex_na2d expects query, key, and value to have the same shape, " + f"got {query.shape=}, {key.shape=}, {value.shape=}." + ) - block_mask = get_block_mask(2, image_shape, kernel_size_, dilation_, is_causal_) + if query.dtype != key.dtype or query.dtype != value.dtype: + raise ValueError( + "flex_na2d expects query, key, and value to have the same dtype, " + f"got {query.dtype=}, {key.dtype=}, {value.dtype=}." 
+ ) + + batch_size, seqlen_1, seqlen_2, num_heads, head_dim = query.shape + seq_length = seqlen_1 * seqlen_2 + input_size = (seqlen_1, seqlen_2) + + query_ = query.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2) + key_ = key.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2) + value_ = value.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2) + + na_mask = get_na_flex_mask(2, input_size, kernel_size_, dilation_, is_causal_) flex_attention_compiled = get_flex_attention_compiled() - out = flex_attention_compiled(_query, _key, _value, block_mask=block_mask) + out_ = flex_attention_compiled(query_, key_, value_, block_mask=na_mask) - out = out.transpose(1, 2).view( - batch_size, image_height, image_width, num_head, head_dim - ) + out = out_.transpose(1, 2).view(batch_size, seqlen_1, seqlen_2, num_heads, head_dim) return out @@ -169,34 +215,43 @@ def flex_na3d( dilation: Dimension3DTypeOrDed = 1, is_causal: Optional[CausalArg3DTypeOrDed] = False, ) -> torch.Tensor: - """ - Args: - query: (batch_size, image_height, image_width, num_head, head_dim) - key: (batch_size, image_height, image_width, num_head, head_dim) - value: (batch_size, image_height, image_width, num_head, head_dim) - kernel_size: Union[int, Tuple[int, int]] - dilation: Union[int, Tuple[int, int]] - is_causal: Union[bool, Tuple[bool, bool]] - """ kernel_size_, dilation_, is_causal_ = check_all_args( 3, kernel_size, dilation, is_causal ) - batch_size, image_depth, image_height, image_width, num_head, head_dim = query.shape - seq_length = image_depth * image_height * image_width - image_shape = (image_depth, image_height, image_width) + if query.dim() != 6 or key.dim() != 6 or value.dim() != 6: + raise ValueError( + "flex_na3d expects query, key, and value to be 6-dimensional tensors, " + f"got {query.shape=}, {key.shape=}, {value.shape=}." + ) + + if query.shape != key.shape or query.shape != value.shape: + raise ValueError( + "flex_na3d expects query, key, and value to have the same shape, " + f"got {query.shape=}, {key.shape=}, {value.shape=}." + ) + + if query.dtype != key.dtype or query.dtype != value.dtype: + raise ValueError( + "flex_na3d expects query, key, and value to have the same dtype, " + f"got {query.dtype=}, {key.dtype=}, {value.dtype=}." 
+ ) + + batch_size, seqlen_0, seqlen_1, seqlen_2, num_heads, head_dim = query.shape + seq_length = seqlen_0 * seqlen_1 * seqlen_2 + input_size = (seqlen_0, seqlen_1, seqlen_2) - _query = query.view(batch_size, seq_length, num_head, head_dim).transpose(1, 2) - _key = key.view(batch_size, seq_length, num_head, head_dim).transpose(1, 2) - _value = value.view(batch_size, seq_length, num_head, head_dim).transpose(1, 2) + query_ = query.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2) + key_ = key.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2) + value_ = value.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2) - block_mask = get_block_mask(3, image_shape, kernel_size_, dilation_, is_causal_) + na_mask = get_na_flex_mask(3, input_size, kernel_size_, dilation_, is_causal_) flex_attention_compiled = get_flex_attention_compiled() - out = flex_attention_compiled(_query, _key, _value, block_mask=block_mask) + out_ = flex_attention_compiled(query_, key_, value_, block_mask=na_mask) - out = out.transpose(1, 2).view( - batch_size, image_depth, image_height, image_width, num_head, head_dim + out = out_.transpose(1, 2).view( + batch_size, seqlen_0, seqlen_1, seqlen_2, num_heads, head_dim ) return out diff --git a/tests/test_compute_delta.py b/tests/test_compute_delta.py index 26a650d..f7d5950 100644 --- a/tests/test_compute_delta.py +++ b/tests/test_compute_delta.py @@ -55,6 +55,7 @@ def _reset_everything(): torch.backends.cudnn.benchmark = False torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False + torch.cuda.empty_cache() HAS_HALF = has_half() diff --git a/tests/test_fna1d.py b/tests/test_fna1d.py index c393c47..85cbae2 100644 --- a/tests/test_fna1d.py +++ b/tests/test_fna1d.py @@ -67,9 +67,10 @@ def _reset_everything(): torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False torch.manual_seed(42) + torch.cuda.empty_cache() # Attention merge recompilation requires this - torch._dynamo.config.cache_size_limit = 64 + torch._dynamo.config.cache_size_limit = 1024 HAS_HALF = has_half() diff --git a/tests/test_fna2d.py b/tests/test_fna2d.py index d46d31d..bec7ac6 100644 --- a/tests/test_fna2d.py +++ b/tests/test_fna2d.py @@ -68,9 +68,10 @@ def _reset_everything(): torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False torch.manual_seed(42) + torch.cuda.empty_cache() # Attention merge recompilation requires this - torch._dynamo.config.cache_size_limit = 64 + torch._dynamo.config.cache_size_limit = 1024 HAS_HALF = has_half() diff --git a/tests/test_fna3d.py b/tests/test_fna3d.py index 7cef58c..878bfc3 100644 --- a/tests/test_fna3d.py +++ b/tests/test_fna3d.py @@ -68,9 +68,10 @@ def _reset_everything(): torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False torch.manual_seed(42) + torch.cuda.empty_cache() # Attention merge recompilation requires this - torch._dynamo.config.cache_size_limit = 64 + torch._dynamo.config.cache_size_limit = 1024 HAS_HALF = has_half() diff --git a/tests/test_na1d.py b/tests/test_na1d.py index 243a9d4..91ae4c0 100644 --- a/tests/test_na1d.py +++ b/tests/test_na1d.py @@ -70,6 +70,7 @@ def _reset_everything(): torch.backends.cudnn.benchmark = False torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False + torch.cuda.empty_cache() HAS_GEMM = has_gemm() diff --git a/tests/test_na2d.py b/tests/test_na2d.py index 938964a..10b842b 100644 --- a/tests/test_na2d.py +++ b/tests/test_na2d.py @@ 
-73,6 +73,7 @@ def _reset_everything():
     torch.backends.cudnn.benchmark = False
     torch.backends.cuda.matmul.allow_tf32 = False
     torch.backends.cudnn.allow_tf32 = False
+    torch.cuda.empty_cache()


 HAS_GEMM = has_gemm()
diff --git a/tests/test_na3d.py b/tests/test_na3d.py
index 599758a..288b182 100644
--- a/tests/test_na3d.py
+++ b/tests/test_na3d.py
@@ -61,6 +61,7 @@ def _reset_everything():
     torch.backends.cudnn.benchmark = False
     torch.backends.cuda.matmul.allow_tf32 = False
     torch.backends.cudnn.allow_tf32 = False
+    torch.cuda.empty_cache()


 HAS_HALF = has_half()

From 5f52f8a1f6595b9fabbd50a12ccc7948c1a25467 Mon Sep 17 00:00:00 2001
From: Fengzhe Zhou
Date: Tue, 4 Mar 2025 02:27:15 -0500
Subject: [PATCH 4/4] Update docs, remove python cache

---
 CHANGELOG.md       |  3 +++
 docs/frontend.md   | 18 ++++++++++++++++++
 src/natten/flex.py |  5 ++---
 3 files changed, 23 insertions(+), 3 deletions(-)

diff --git a/CHANGELOG.md b/CHANGELOG.md
index ddcff43..fb85f62 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -2,6 +2,9 @@

 ## [Main branch]

+## [0.17.5.dev0] - 2025-03-04
+* Integrate flex attention backend.
+
 ## [0.17.4] - 2025-01-28
 * Support for additional KV tokens in FNA (requires xFormers)
 * Adds experimental support for additional KV tokens (attend to local neighborhood, and some
diff --git a/docs/frontend.md b/docs/frontend.md
index 461c989..cf6666b 100644
--- a/docs/frontend.md
+++ b/docs/frontend.md
@@ -165,6 +165,24 @@ Future versions may offer more fine-grained control over this.

 For more information, refer to [KV parallelism](fna/kv-parallelism.md).

+#### Using FlexAttention Backend
+
+NATTEN also supports PyTorch's FlexAttention as a backend, which can be enabled as follows:
+
+```python
+from natten import use_flex_attention
+
+use_flex_attention(True)
+# Enable FlexAttention Backend
+
+use_flex_attention(False)
+# Disable FlexAttention Backend (default)
+```
+
+FlexAttention can potentially be faster than FNA on modern GPU architectures, especially for higher-dimensional problems (2-D and 3-D).
+
+However, the FlexAttention backend is still experimental and may contain bugs stemming from its kernel implementation. Bug reports related to this backend are strongly appreciated.
+
 ### Operations

 Operations are one level below our modules, and are intended to give you full control over the module-level details, and only use the underlying neighborhood attention operators directly.
diff --git a/src/natten/flex.py b/src/natten/flex.py
index 0bffe88..0f94ed4 100644
--- a/src/natten/flex.py
+++ b/src/natten/flex.py
@@ -42,7 +42,6 @@
 from .utils import check_all_args


-@functools.lru_cache(maxsize=1)
 def get_flex_attention_compiled():
     return torch.compile(flex_attention, dynamic=False)

@@ -108,7 +107,7 @@ def na_mask_mod(

     seq_length = math.prod(input_size)
     return create_block_mask(
-        na_mask_mod, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length  # type: ignore
+        na_mask_mod, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length, _compile=True  # type: ignore
     )

@@ -127,7 +126,7 @@ def flex_na1d(

     if query.dim() != 4 or key.dim() != 4 or value.dim() != 4:
         raise ValueError(
-            "flex_na1d expects query, key, and value to be 5-dimensional tensors, "
+            "flex_na1d expects query, key, and value to be 4-dimensional tensors, "
             f"got {query.shape=}, {key.shape=}, {value.shape=}."
         )
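
For orientation, here is a minimal usage sketch (not part of the patch series itself) of the backend introduced above. It assumes a CUDA device, a PyTorch version that ships `torch.nn.attention.flex_attention`, and a NATTEN build that includes these patches; the tensor shapes, kernel size, and dtype are arbitrary illustrative choices.

```python
# Minimal usage sketch of the Flex Attention backend added in this series.
# Assumes CUDA is available and NATTEN was built from a tree containing these patches.
import torch

import natten
from natten.functional import na2d

# Route fused neighborhood attention through the Flex Attention backend,
# mirroring what tools/profile_2d.py now does when passed the new --flex flag.
natten.use_fused_na(True, use_flex_attention=True)

# 2-D inputs are laid out as (batch, height, width, heads, head_dim),
# the same layout used in tests/test_fna2d.py.
q = torch.randn(2, 16, 16, 4, 64, device="cuda", dtype=torch.float16)
k = torch.randn_like(q)
v = torch.randn_like(q)

# Same call signature as the CUTLASS FNA path; rpb, a custom scale, and
# additional KV tokens raise NotImplementedError while this backend is active.
out = na2d(q, k, v, kernel_size=7, dilation=1, is_causal=False)
print(out.shape)  # torch.Size([2, 16, 16, 4, 64])
```

When the flex path is selected, `na2d` dispatches to `flex_na2d`, which flattens the 2-D token grid into a sequence, builds the neighborhood `BlockMask` through `create_block_mask` (memoized per problem size by `functools.lru_cache`), and runs the compiled (via `torch.compile`) `flex_attention` kernel.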