diff --git a/CHANGELOG.md b/CHANGELOG.md index ddcff43..fb85f62 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -2,6 +2,9 @@ ## [Main branch] +## [0.17.5.dev0] - 2025-03-04 +* Integrate FlexAttention backend. + ## [0.17.4] - 2025-01-28 * Support for additional KV tokens in FNA (requires xFormers) * Adds experimental support for additional KV tokens (attend to local neighborhood, and some diff --git a/Makefile b/Makefile index 7e13381..d087de1 100644 --- a/Makefile +++ b/Makefile @@ -55,12 +55,15 @@ uninstall: install: @echo "Installing NATTEN from source" NATTEN_CUDA_ARCH="${CUDA_ARCH}" \ - NATTEN_N_WORKERS="${WORKERS}" \ - NATTEN_WITH_CUDA="${WITH_CUDA}" \ - NATTEN_VERBOSE="${VERBOSE}" \ - pip install -v -e . 2>&1 | tee install.out + NATTEN_N_WORKERS="${WORKERS}" \ + NATTEN_WITH_CUDA="${WITH_CUDA}" \ + NATTEN_VERBOSE="${VERBOSE}" \ + pip install -v -e . 2>&1 | tee install.out test: + NATTEN_LOG_LEVEL="CRITICAL" \ + PYTORCH_NO_CUDA_MEMORY_CACHING=1 \ + CUBLAS_WORKSPACE_CONFIG=":4096:8" \ pytest -v -x ./tests style: diff --git a/docs/frontend.md b/docs/frontend.md index 461c989..cf6666b 100644 --- a/docs/frontend.md +++ b/docs/frontend.md @@ -165,6 +165,24 @@ Future versions may offer more fine-grained control over this. For more information, refer to [KV parallelism](fna/kv-parallelism.md). +#### Using the FlexAttention Backend + +NATTEN also supports the FlexAttention backend, which can be enabled as follows: + +```python +from natten import use_flex_attention, use_fused_na + +use_flex_attention() +# Enable the FlexAttention backend (this also enables fused neighborhood attention) + +use_fused_na(True, use_flex_attention=False) +# Fall back to the default FNA backend +``` + +FlexAttention can potentially be faster than FNA on modern GPU architectures, especially for higher-dimensional problems (2-D and 3-D). + +However, the FlexAttention backend is still experimental and may contain bugs stemming from its kernel implementation. Bug reports related to this backend are strongly appreciated. + ### Operations Operations are one level below our modules, and are intended to give you full control over the module-level details, and only use the underlying neighborhood attention operators directly. 
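To complement the documentation snippet above, here is a minimal end-to-end sketch of the new toggle together with NATTEN's functional interface as wired up in this diff. It assumes a CUDA device and a PyTorch build with FlexAttention support; the tensor sizes and dtype are arbitrary, illustrative choices rather than part of the patch.

```python
# Illustrative sketch only; exercises the backend toggle introduced in this diff.
import torch

from natten import use_flex_attention, use_fused_na
from natten.functional import na2d

use_flex_attention()  # routes na1d/na2d/na3d through flex_na* (also enables fused NA)

B, X, Y, H, D = 2, 32, 32, 4, 64  # batch, height, width, heads, head_dim (assumed sizes)
q, k, v = (
    torch.randn(B, X, Y, H, D, device="cuda", dtype=torch.float16) for _ in range(3)
)

# 7x7 neighborhood attention; dispatches to flex_na2d because the context flag is set.
out = na2d(q, k, v, kernel_size=(7, 7), dilation=(1, 1))
assert out.shape == (B, X, Y, H, D)

use_fused_na(True, use_flex_attention=False)  # switch back to the default FNA backend
```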
diff --git a/src/natten/__init__.py b/src/natten/__init__.py index f5e6e42..bdb65b8 100644 --- a/src/natten/__init__.py +++ b/src/natten/__init__.py @@ -114,6 +114,7 @@ "disable_gemm_na", "enable_tiled_na", "disable_tiled_na", + "use_flex_attention", ] -__version__ = "0.17.4" +__version__ = "0.17.5.dev0" diff --git a/src/natten/context.py b/src/natten/context.py index fc27a35..90593d2 100644 --- a/src/natten/context.py +++ b/src/natten/context.py @@ -46,6 +46,7 @@ class NattenContext: is_deterministic_mode_enabled: bool = False is_fused_na_enabled: bool = False is_kv_parallelism_enabled: bool = False + use_flex_attention: bool = False training_memory_preference: MemoryUsagePreference = MemoryUsagePreference.Default @@ -53,6 +54,7 @@ class NattenContext: def reset(): NattenContext.is_deterministic_mode_enabled = False NattenContext.is_fused_na_enabled = False + NattenContext.use_flex_attention = False NattenContext.is_kv_parallelism_enabled = False NattenContext.training_memory_preference = MemoryUsagePreference.Default @@ -133,9 +135,12 @@ def is_kv_parallelism_in_fused_na_enabled() -> bool: return NattenContext.is_kv_parallelism_enabled -def use_fused_na(mode: bool = True, kv_parallel: bool = True): +def use_fused_na( + mode: bool = True, kv_parallel: bool = True, use_flex_attention: bool = False +): if not mode: NattenContext.is_fused_na_enabled = False + NattenContext.use_flex_attention = False use_kv_parallelism_in_fused_na(False) return @@ -147,12 +152,21 @@ def use_fused_na(mode: bool = True, kv_parallel: bool = True): ) use_kv_parallelism_in_fused_na(kv_parallel) NattenContext.is_fused_na_enabled = True + NattenContext.use_flex_attention = use_flex_attention def is_fused_na_enabled() -> bool: return NattenContext.is_fused_na_enabled +def should_use_flex_attention() -> bool: + return NattenContext.use_flex_attention + + +def use_flex_attention() -> bool: + return use_fused_na(mode=True, use_flex_attention=True) + + use_fna = use_fused_na is_fna_enabled = is_fused_na_enabled diff --git a/src/natten/flex.py b/src/natten/flex.py new file mode 100644 index 0000000..0f94ed4 --- /dev/null +++ b/src/natten/flex.py @@ -0,0 +1,256 @@ +################################################################################################# +# Copyright (c) 2022-2024 Ali Hassani. +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+# +################################################################################################# + +import functools +import math +from typing import Optional, Tuple + +import torch +from torch import BoolTensor, IntTensor, Tensor +from torch.nn.attention.flex_attention import create_block_mask, flex_attention + +from .types import ( + CausalArg1DTypeOrDed, + CausalArg2DTypeOrDed, + CausalArg3DTypeOrDed, + CausalArgType, + Dimension1DTypeOrDed, + Dimension2DTypeOrDed, + Dimension3DTypeOrDed, + DimensionType, +) +from .utils import check_all_args + + +def get_flex_attention_compiled(): + return torch.compile(flex_attention, dynamic=False) + + +@functools.lru_cache(maxsize=None) +def get_na_flex_mask( + na_dim: int, + input_size: DimensionType, + kernel_size: DimensionType, + dilation: DimensionType, + is_causal: CausalArgType, +): + + def index_to_coord_1d(idx: IntTensor) -> Tuple[IntTensor]: + assert len(input_size) == 1 + return (idx,) + + def index_to_coord_2d(idx: IntTensor) -> Tuple[IntTensor, IntTensor]: + assert len(input_size) == 2 + return (idx // input_size[1], idx % input_size[1]) # type: ignore + + def index_to_coord_3d(idx: IntTensor) -> Tuple[IntTensor, IntTensor, IntTensor]: + assert len(input_size) == 3 + return ( + idx // input_size[2] // input_size[1], # type: ignore + (idx // input_size[2]) % input_size[1], # type: ignore + idx % input_size[2], # type: ignore + ) + + index_to_coord = { + 1: index_to_coord_1d, + 2: index_to_coord_2d, + 3: index_to_coord_3d, + }[na_dim] + + def na_mask_mod( + b: IntTensor, h: IntTensor, q_idx: IntTensor, kv_idx: IntTensor + ) -> BoolTensor: + q_coord = index_to_coord(q_idx) + kv_coord = index_to_coord(kv_idx) + + masks = [] + for i in range(na_dim): + kernel_times_dilation = kernel_size[i] * dilation[i] + if is_causal[i]: + mask = ( + (q_coord[i] - kv_coord[i] >= 0) + & (q_coord[i] - kv_coord[i] < kernel_times_dilation) + & ((q_coord[i] % dilation[i]) == (kv_coord[i] % dilation[i])) + ) + else: + kernel_center_x = q_coord[i].clamp( + (kernel_times_dilation - 1) // 2, + (input_size[i] - 1) - (kernel_times_dilation - 1) // 2, + ) + mask = ( + (kernel_center_x - kv_coord[i]).abs() <= kernel_times_dilation // 2 + ) & ((q_coord[i] % dilation[i]) == (kv_coord[i] % dilation[i])) + + masks.append(mask) + + return functools.reduce(lambda x, y: x & y, masks) # type: ignore + + seq_length = math.prod(input_size) + return create_block_mask( + na_mask_mod, B=None, H=None, Q_LEN=seq_length, KV_LEN=seq_length, _compile=True # type: ignore + ) + + +def flex_na1d( + query: Tensor, + key: Tensor, + value: Tensor, + kernel_size: Dimension1DTypeOrDed, + dilation: Dimension1DTypeOrDed = 1, + is_causal: Optional[CausalArg1DTypeOrDed] = False, +) -> torch.Tensor: + + kernel_size_, dilation_, is_causal_ = check_all_args( + 1, kernel_size, dilation, is_causal + ) + + if query.dim() != 4 or key.dim() != 4 or value.dim() != 4: + raise ValueError( + "flex_na1d expects query, key, and value to be 4-dimensional tensors, " + f"got {query.shape=}, {key.shape=}, {value.shape=}." + ) + + if query.shape != key.shape or query.shape != value.shape: + raise ValueError( + "flex_na1d expects query, key, and value to have the same shape, " + f"got {query.shape=}, {key.shape=}, {value.shape=}." + ) + + if query.dtype != key.dtype or query.dtype != value.dtype: + raise ValueError( + "flex_na1d expects query, key, and value to have the same dtype, " + f"got {query.dtype=}, {key.dtype=}, {value.dtype=}." 
+ ) + + batch_size, seqlen, num_heads, head_dim = query.shape + input_size = (seqlen,) + + query_ = query.transpose(1, 2) + key_ = key.transpose(1, 2) + value_ = value.transpose(1, 2) + + na_mask = get_na_flex_mask(1, input_size, kernel_size_, dilation_, is_causal_) + flex_attention_compiled = get_flex_attention_compiled() + out_ = flex_attention_compiled(query_, key_, value_, block_mask=na_mask) + + out = out_.transpose(1, 2) + + return out + + +def flex_na2d( + query: Tensor, + key: Tensor, + value: Tensor, + kernel_size: Dimension2DTypeOrDed, + dilation: Dimension2DTypeOrDed = 1, + is_causal: Optional[CausalArg2DTypeOrDed] = False, +) -> torch.Tensor: + + kernel_size_, dilation_, is_causal_ = check_all_args( + 2, kernel_size, dilation, is_causal + ) + + if query.dim() != 5 or key.dim() != 5 or value.dim() != 5: + raise ValueError( + "flex_na2d expects query, key, and value to be 5-dimensional tensors, " + f"got {query.shape=}, {key.shape=}, {value.shape=}." + ) + + if query.shape != key.shape or query.shape != value.shape: + raise ValueError( + "flex_na2d expects query, key, and value to have the same shape, " + f"got {query.shape=}, {key.shape=}, {value.shape=}." + ) + + if query.dtype != key.dtype or query.dtype != value.dtype: + raise ValueError( + "flex_na2d expects query, key, and value to have the same dtype, " + f"got {query.dtype=}, {key.dtype=}, {value.dtype=}." + ) + + batch_size, seqlen_1, seqlen_2, num_heads, head_dim = query.shape + seq_length = seqlen_1 * seqlen_2 + input_size = (seqlen_1, seqlen_2) + + query_ = query.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2) + key_ = key.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2) + value_ = value.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2) + + na_mask = get_na_flex_mask(2, input_size, kernel_size_, dilation_, is_causal_) + flex_attention_compiled = get_flex_attention_compiled() + out_ = flex_attention_compiled(query_, key_, value_, block_mask=na_mask) + + out = out_.transpose(1, 2).view(batch_size, seqlen_1, seqlen_2, num_heads, head_dim) + + return out + + +def flex_na3d( + query: Tensor, + key: Tensor, + value: Tensor, + kernel_size: Dimension3DTypeOrDed, + dilation: Dimension3DTypeOrDed = 1, + is_causal: Optional[CausalArg3DTypeOrDed] = False, +) -> torch.Tensor: + + kernel_size_, dilation_, is_causal_ = check_all_args( + 3, kernel_size, dilation, is_causal + ) + + if query.dim() != 6 or key.dim() != 6 or value.dim() != 6: + raise ValueError( + "flex_na3d expects query, key, and value to be 6-dimensional tensors, " + f"got {query.shape=}, {key.shape=}, {value.shape=}." + ) + + if query.shape != key.shape or query.shape != value.shape: + raise ValueError( + "flex_na3d expects query, key, and value to have the same shape, " + f"got {query.shape=}, {key.shape=}, {value.shape=}." + ) + + if query.dtype != key.dtype or query.dtype != value.dtype: + raise ValueError( + "flex_na3d expects query, key, and value to have the same dtype, " + f"got {query.dtype=}, {key.dtype=}, {value.dtype=}." 
+ ) + + batch_size, seqlen_0, seqlen_1, seqlen_2, num_heads, head_dim = query.shape + seq_length = seqlen_0 * seqlen_1 * seqlen_2 + input_size = (seqlen_0, seqlen_1, seqlen_2) + + query_ = query.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2) + key_ = key.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2) + value_ = value.view(batch_size, seq_length, num_heads, head_dim).transpose(1, 2) + + na_mask = get_na_flex_mask(3, input_size, kernel_size_, dilation_, is_causal_) + flex_attention_compiled = get_flex_attention_compiled() + out_ = flex_attention_compiled(query_, key_, value_, block_mask=na_mask) + + out = out_.transpose(1, 2).view( + batch_size, seqlen_0, seqlen_1, seqlen_2, num_heads, head_dim + ) + + return out diff --git a/src/natten/functional.py b/src/natten/functional.py index 5f389dc..f45d77e 100644 --- a/src/natten/functional.py +++ b/src/natten/functional.py @@ -51,6 +51,8 @@ ) from .autotuner import autotune_fna +from .context import should_use_flex_attention +from .flex import flex_na1d, flex_na2d, flex_na3d from .nested import ( na1d_av_nested, na1d_qk_nested, @@ -1721,6 +1723,29 @@ def na1d( "Fused neighborhood attention does not support nested tensors yet." ) + if should_use_flex_attention(): + if scale is not None: + raise NotImplementedError( + "Custom attention scale is not supported in the Flex Attention backend." + ) + if rpb is not None: + raise NotImplementedError( + "RPB is not supported in the Flex Attention backend." + ) + if additional_keys is not None or additional_values is not None: + raise NotImplementedError( + "Additional keys/values is not supported in the Flex Attention backend." + ) + + return flex_na1d( + query, + key, + value, + kernel_size, + dilation, + is_causal, + ) + tiling_config_forward, tiling_config_backward = autotune_fna( 1, query, kernel_size, dilation, is_causal ) @@ -1777,6 +1802,29 @@ def na2d( "Fused neighborhood attention does not support nested tensors yet." ) + if should_use_flex_attention(): + if scale is not None: + raise NotImplementedError( + "Custom attention scale is not supported in the Flex Attention backend." + ) + if rpb is not None: + raise NotImplementedError( + "RPB is not supported in the Flex Attention backend." + ) + if additional_keys is not None or additional_values is not None: + raise NotImplementedError( + "Additional keys/values is not supported in the Flex Attention backend." + ) + + return flex_na2d( + query, + key, + value, + kernel_size, + dilation, + is_causal, + ) + tiling_config_forward, tiling_config_backward = autotune_fna( 2, query, kernel_size, dilation, is_causal ) @@ -1833,6 +1881,29 @@ def na3d( "Fused neighborhood attention does not support nested tensors yet." ) + if should_use_flex_attention(): + if scale is not None: + raise NotImplementedError( + "Custom attention scale is not supported in the Flex Attention backend." + ) + if rpb is not None: + raise NotImplementedError( + "RPB is not supported in the Flex Attention backend." + ) + if additional_keys is not None or additional_values is not None: + raise NotImplementedError( + "Additional keys/values is not supported in the Flex Attention backend." 
+ ) + + return flex_na3d( + query, + key, + value, + kernel_size, + dilation, + is_causal, + ) + tiling_config_forward, tiling_config_backward = autotune_fna( 3, query, kernel_size, dilation, is_causal ) diff --git a/tests/test_compute_delta.py b/tests/test_compute_delta.py index 26a650d..f7d5950 100644 --- a/tests/test_compute_delta.py +++ b/tests/test_compute_delta.py @@ -55,6 +55,7 @@ def _reset_everything(): torch.backends.cudnn.benchmark = False torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False + torch.cuda.empty_cache() HAS_HALF = has_half() diff --git a/tests/test_fna1d.py b/tests/test_fna1d.py index ed33cf5..85cbae2 100644 --- a/tests/test_fna1d.py +++ b/tests/test_fna1d.py @@ -33,6 +33,7 @@ use_autotuner, use_kv_parallelism_in_fused_na, ) +from natten.flex import flex_na1d from natten.functional import na1d, na1d_av, na1d_qk from natten.utils import check_all_args from natten.utils.testing import ( @@ -66,9 +67,10 @@ def _reset_everything(): torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False torch.manual_seed(42) + torch.cuda.empty_cache() # Attention merge recompilation requires this - torch._dynamo.config.cache_size_limit = 64 + torch._dynamo.config.cache_size_limit = 1024 HAS_HALF = has_half() @@ -533,6 +535,163 @@ def test_against_sdpa(self): ) +class FlexAttentionFNA1DTest(unittest.TestCase): + def setUp(self): + _reset_everything() + + def tearDown(self): + _reset_everything() + + def _test_against_cutlass_fna( + self, B, H, L, D, kernel_size, dilation, is_causal, eps, dtype + ): + kernel_size, dilation, is_causal = check_args(kernel_size, dilation, is_causal) + with torch.no_grad(): + q, k, v, d_out = ( + torch.randn((B, L, H, D), device="cuda", dtype=dtype), + torch.randn((B, L, H, D), device="cuda", dtype=dtype), + torch.randn((B, L, H, D), device="cuda", dtype=dtype), + torch.randn((B, L, H, D), device="cuda", dtype=dtype) * 0.05, + ) + + q_ref, k_ref, v_ref, d_out_ref = ( + q.clone(), + k.clone(), + v.clone(), + d_out.clone(), + ) + + # Reference + q_ref.requires_grad_(True) + k_ref.requires_grad_(True) + v_ref.requires_grad_(True) + d_out_ref.requires_grad_(True) + out_ref_ = na1d( + q_ref, + k_ref, + v_ref, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + ) + out_ref = out_ref_.data.clone().float() + + dq_ref, dk_ref, dv_ref = None, None, None + out_ref_.backward(d_out_ref) + with torch.no_grad(): + dq_ref, dk_ref, dv_ref = ( + q_ref.grad.clone().float(), + k_ref.grad.clone().float(), + v_ref.grad.clone().float(), + ) + + # Flex + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + d_out.requires_grad_(True) + + out_ = flex_na1d( + q, + k, + v, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + ) + out = out_.data.clone().float() + + dq, dk, dv = None, None, None + out_.backward(d_out) + with torch.no_grad(): + dq, dk, dv = ( + q.grad.clone().float(), + k.grad.clone().float(), + v.grad.clone().float(), + ) + + torch.testing.assert_close(out, out_ref, atol=eps, rtol=0) + torch.testing.assert_close(dq, dq_ref, atol=eps, rtol=0) + torch.testing.assert_close(dk, dk_ref, atol=eps, rtol=0) + torch.testing.assert_close(dv, dv_ref, atol=eps, rtol=0) + + def _test_all_dtypes( + self, + B, + H, + L, + D, + kernel_size, + dilation, + is_causal=None, + ): + self._test_against_cutlass_fna( + B=B, + H=H, + L=L, + D=D, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + dtype=torch.float32, + eps=1e-2, + ) + if HAS_HALF: + 
self._test_against_cutlass_fna( + B=B, + H=H, + L=L, + D=D, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + dtype=torch.float16, + eps=1e-1, + ) + if HAS_BFLOAT: + self._test_against_cutlass_fna( + B=B, + H=H, + L=L, + D=D, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + dtype=torch.bfloat16, + eps=1e-1, + ) + + @skip_if_cuda_is_not_supported() + @skip_if_fna_is_not_supported() + def test_against_cutlass_fna(self): + problem_sizes = [ + (1, 1, 3, 16, 3, 1), + (1, 1, 16, 32, 3, 1), + (1, 2, 33, 32, 15, 1), + (1, 2, 33, 64, 15, 2), + (4, 3, 256, 64, 255, 1), + (2, 2, 4096, 64, 2047, 1), + (2, 4, 4096, 64, 2047, 2), + (4, 3, 5000, 64, 511, 8), + (4, 3, 5000, 64, 255, 16), + (1, 12, 512, 64, 255, 1), + # TODO: these will fail on most non-A100/H100 cards due to the 99KB shmem limit + # (4, 24, 512, 128, 99, 1), + # (1, 48, 512, 256, 45, 4), + ] + for B, H, L, D, kernel_size, dilation in problem_sizes: + for is_causal in [False, True]: + self._test_all_dtypes( + B=B, + H=H, + L=L, + D=D, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + ) + + if __name__ == "__main__": torch.manual_seed(42) unittest.main() diff --git a/tests/test_fna2d.py b/tests/test_fna2d.py index 7f25261..bec7ac6 100644 --- a/tests/test_fna2d.py +++ b/tests/test_fna2d.py @@ -34,6 +34,7 @@ use_autotuner, use_kv_parallelism_in_fused_na, ) +from natten.flex import flex_na2d from natten.functional import na2d, na2d_av, na2d_qk from natten.utils import check_all_args from natten.utils.testing import ( @@ -67,9 +68,10 @@ def _reset_everything(): torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False torch.manual_seed(42) + torch.cuda.empty_cache() # Attention merge recompilation requires this - torch._dynamo.config.cache_size_limit = 64 + torch._dynamo.config.cache_size_limit = 1024 HAS_HALF = has_half() @@ -550,6 +552,182 @@ def test_against_sdpa(self): ) +class FlexAttentionFNA2DTest(unittest.TestCase): + def setUp(self): + _reset_everything() + + def tearDown(self): + _reset_everything() + + def _test_against_cutlass_fna( + self, B, H, X, Y, D, kernel_size, dilation, is_causal, eps, dtype + ): + kernel_size, dilation, is_causal = check_args(kernel_size, dilation, is_causal) + with torch.no_grad(): + q, k, v, d_out = ( + torch.randn((B, X, Y, H, D), device="cuda", dtype=dtype), + torch.randn((B, X, Y, H, D), device="cuda", dtype=dtype), + torch.randn((B, X, Y, H, D), device="cuda", dtype=dtype), + torch.randn((B, X, Y, H, D), device="cuda", dtype=dtype) * 0.05, + ) + + q_ref, k_ref, v_ref, d_out_ref = ( + q.clone(), + k.clone(), + v.clone(), + d_out.clone(), + ) + + # Reference + q_ref.requires_grad_(True) + k_ref.requires_grad_(True) + v_ref.requires_grad_(True) + d_out_ref.requires_grad_(True) + out_ref_ = na2d( + q_ref, + k_ref, + v_ref, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + ) + out_ref = out_ref_.data.clone().float() + + dq_ref, dk_ref, dv_ref = None, None, None + out_ref_.backward(d_out_ref) + with torch.no_grad(): + dq_ref, dk_ref, dv_ref = ( + q_ref.grad.clone().float(), + k_ref.grad.clone().float(), + v_ref.grad.clone().float(), + ) + + # Flex + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + d_out.requires_grad_(True) + + out_ = flex_na2d( + q, + k, + v, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + ) + out = out_.data.clone().float() + + dq, dk, dv = None, None, None + out_.backward(d_out) + with torch.no_grad(): + dq, dk, dv = ( 
+ q.grad.clone().float(), + k.grad.clone().float(), + v.grad.clone().float(), + ) + + torch.testing.assert_close(out, out_ref, atol=eps, rtol=0) + torch.testing.assert_close(dq, dq_ref, atol=eps, rtol=0) + torch.testing.assert_close(dk, dk_ref, atol=eps, rtol=0) + torch.testing.assert_close(dv, dv_ref, atol=eps, rtol=0) + + def _test_all_dtypes( + self, + B, + H, + X, + Y, + D, + kernel_size, + dilation, + is_causal=None, + ): + self._test_against_cutlass_fna( + B=B, + H=H, + X=X, + Y=Y, + D=D, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + dtype=torch.float32, + eps=1e-2, + ) + if HAS_HALF: + self._test_against_cutlass_fna( + B=B, + H=H, + X=X, + Y=Y, + D=D, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + dtype=torch.float16, + eps=1e-1, + ) + if HAS_BFLOAT: + self._test_against_cutlass_fna( + B=B, + H=H, + X=X, + Y=Y, + D=D, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + dtype=torch.bfloat16, + eps=1e-1, + ) + + @skip_if_cuda_is_not_supported() + @skip_if_fna_is_not_supported() + def test_against_cutlass_fna(self): + problem_sizes = [ + (1, 1, 3, 3, 16, 3, 3, 1, 1), + (1, 1, 8, 10, 16, 3, 5, 1, 2), + (1, 2, 15, 20, 32, 5, 15, 3, 1), + (1, 2, 17, 19, 32, 9, 7, 1, 1), + (1, 2, 17, 19, 32, 9, 7, 1, 2), + (4, 3, 32, 32, 32, 31, 31, 1, 1), + (2, 2, 32, 64, 64, 25, 31, 1, 2), + (2, 4, 64, 128, 64, 55, 101, 1, 1), + (2, 4, 64, 128, 64, 21, 29, 3, 4), + # TODO: these will fail on most non-A100/H100 cards due to the 99KB shmem limit + # (4, 3, 56, 56, 128, 7, 7, 2, 4), + # (4, 3, 28, 46, 128, 11, 13, 1, 1), + (4, 3, 56, 56, 64, 7, 7, 2, 4), + (4, 3, 28, 46, 64, 11, 13, 1, 1), + ] + for ( + B, + H, + X, + Y, + D, + kernel_size_h, + kernel_size_w, + dilation_h, + dilation_w, + ) in problem_sizes: + for causal_h, causal_w in product([True, False], [True, False]): + kernel_size = (kernel_size_h, kernel_size_w) + dilation = (dilation_h, dilation_w) + is_causal = (causal_h, causal_w) + self._test_all_dtypes( + B=B, + H=H, + X=X, + Y=Y, + D=D, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + ) + + if __name__ == "__main__": torch.manual_seed(42) unittest.main() diff --git a/tests/test_fna3d.py b/tests/test_fna3d.py index 18fee35..878bfc3 100644 --- a/tests/test_fna3d.py +++ b/tests/test_fna3d.py @@ -34,6 +34,7 @@ use_autotuner, use_kv_parallelism_in_fused_na, ) +from natten.flex import flex_na3d from natten.functional import na3d, na3d_av, na3d_qk from natten.utils import check_all_args from natten.utils.testing import ( @@ -67,9 +68,10 @@ def _reset_everything(): torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False torch.manual_seed(42) + torch.cuda.empty_cache() # Attention merge recompilation requires this - torch._dynamo.config.cache_size_limit = 64 + torch._dynamo.config.cache_size_limit = 1024 HAS_HALF = has_half() @@ -559,6 +561,184 @@ def test_against_sdpa(self): ) +class FlexAttentionFNA3DTest(unittest.TestCase): + def setUp(self): + _reset_everything() + + def tearDown(self): + _reset_everything() + + def _test_against_cutlass_fna( + self, B, H, X, Y, Z, D, kernel_size, dilation, is_causal, eps, dtype + ): + kernel_size, dilation, is_causal = check_args(kernel_size, dilation, is_causal) + with torch.no_grad(): + q, k, v, d_out = ( + torch.randn((B, X, Y, Z, H, D), device="cuda", dtype=dtype), + torch.randn((B, X, Y, Z, H, D), device="cuda", dtype=dtype), + torch.randn((B, X, Y, Z, H, D), device="cuda", dtype=dtype), + torch.randn((B, X, Y, Z, H, D), 
device="cuda", dtype=dtype) * 0.05, + ) + + q_ref, k_ref, v_ref, d_out_ref = ( + q.clone(), + k.clone(), + v.clone(), + d_out.clone(), + ) + + # Reference + q_ref.requires_grad_(True) + k_ref.requires_grad_(True) + v_ref.requires_grad_(True) + d_out_ref.requires_grad_(True) + out_ref_ = na3d( + q_ref, + k_ref, + v_ref, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + ) + out_ref = out_ref_.data.clone().float() + + dq_ref, dk_ref, dv_ref = None, None, None + out_ref_.backward(d_out_ref) + with torch.no_grad(): + dq_ref, dk_ref, dv_ref = ( + q_ref.grad.clone().float(), + k_ref.grad.clone().float(), + v_ref.grad.clone().float(), + ) + + # Flex + q.requires_grad_(True) + k.requires_grad_(True) + v.requires_grad_(True) + d_out.requires_grad_(True) + + out_ = flex_na3d( + q, + k, + v, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + ) + out = out_.data.clone().float() + + dq, dk, dv = None, None, None + out_.backward(d_out) + with torch.no_grad(): + dq, dk, dv = ( + q.grad.clone().float(), + k.grad.clone().float(), + v.grad.clone().float(), + ) + + torch.testing.assert_close(out, out_ref, atol=eps, rtol=0) + torch.testing.assert_close(dq, dq_ref, atol=eps, rtol=0) + torch.testing.assert_close(dk, dk_ref, atol=eps, rtol=0) + torch.testing.assert_close(dv, dv_ref, atol=eps, rtol=0) + + def _test_all_dtypes( + self, + B, + H, + X, + Y, + Z, + D, + kernel_size, + dilation, + is_causal=None, + ): + self._test_against_cutlass_fna( + B=B, + H=H, + X=X, + Y=Y, + Z=Z, + D=D, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + dtype=torch.float32, + eps=1e-2, + ) + if HAS_HALF: + self._test_against_cutlass_fna( + B=B, + H=H, + X=X, + Y=Y, + Z=Z, + D=D, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + dtype=torch.float16, + eps=1e-1, + ) + if HAS_BFLOAT: + self._test_against_cutlass_fna( + B=B, + H=H, + X=X, + Y=Y, + Z=Z, + D=D, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + dtype=torch.bfloat16, + eps=1e-1, + ) + + @skip_if_cuda_is_not_supported() + @skip_if_fna_is_not_supported() + def test_against_cutlass_fna(self): + problem_sizes = [ + (1, 1, 3, 3, 3, 16, 3, 3, 3, 1, 1, 1), + (1, 2, 6, 8, 12, 16, 5, 7, 11, 1, 1, 1), + (1, 4, 6, 8, 12, 32, 3, 3, 3, 2, 2, 4), + (2, 2, 6, 8, 12, 32, 3, 3, 3, 1, 1, 1), + (1, 12, 32, 8, 8, 64, 7, 5, 5, 2, 1, 1), + (4, 8, 32, 10, 10, 64, 7, 3, 3, 1, 2, 3), + ] + for ( + B, + H, + X, + Y, + Z, + D, + kernel_size_d, + kernel_size_h, + kernel_size_w, + dilation_d, + dilation_h, + dilation_w, + ) in problem_sizes: + for causal_d, causal_h, causal_w in product( + [True, False], [True, False], [True, False] + ): + kernel_size = (kernel_size_d, kernel_size_h, kernel_size_w) + dilation = (dilation_d, dilation_h, dilation_w) + is_causal = (causal_d, causal_h, causal_w) + self._test_all_dtypes( + B=B, + H=H, + X=X, + Y=Y, + Z=Z, + D=D, + kernel_size=kernel_size, + dilation=dilation, + is_causal=is_causal, + ) + + if __name__ == "__main__": torch.manual_seed(42) unittest.main() diff --git a/tests/test_na1d.py b/tests/test_na1d.py index 243a9d4..91ae4c0 100644 --- a/tests/test_na1d.py +++ b/tests/test_na1d.py @@ -70,6 +70,7 @@ def _reset_everything(): torch.backends.cudnn.benchmark = False torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False + torch.cuda.empty_cache() HAS_GEMM = has_gemm() diff --git a/tests/test_na2d.py b/tests/test_na2d.py index 938964a..10b842b 100644 --- a/tests/test_na2d.py +++ b/tests/test_na2d.py @@ -73,6 +73,7 
@@ def _reset_everything(): torch.backends.cudnn.benchmark = False torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False + torch.cuda.empty_cache() HAS_GEMM = has_gemm() diff --git a/tests/test_na3d.py b/tests/test_na3d.py index 599758a..288b182 100644 --- a/tests/test_na3d.py +++ b/tests/test_na3d.py @@ -61,6 +61,7 @@ def _reset_everything(): torch.backends.cudnn.benchmark = False torch.backends.cuda.matmul.allow_tf32 = False torch.backends.cudnn.allow_tf32 = False + torch.cuda.empty_cache() HAS_HALF = has_half() diff --git a/tools/profile_1d.py b/tools/profile_1d.py index 9d7b69e..3ed384e 100644 --- a/tools/profile_1d.py +++ b/tools/profile_1d.py @@ -55,6 +55,7 @@ @click.option("--fmha", is_flag=True) @click.option("--fav2", is_flag=True) @click.option("--backprop", is_flag=True) +@click.option("--flex", is_flag=True) @click.option("--add-kv", default=0) def profile_1d( batch_size: int, @@ -76,6 +77,7 @@ def profile_1d( fav2: bool, backprop: bool, add_kv: int, + flex: bool, ): dtype = torch.float32 @@ -90,7 +92,7 @@ def profile_1d( natten.libnatten.set_gemm_tf32(False) if fuse: - natten.use_fused_na() + natten.use_fused_na(True, use_flex_attention=flex) natten.use_kv_parallelism_in_fused_na() natten.set_memory_usage_preference("unrestricted") @@ -98,6 +100,8 @@ def profile_1d( natten.use_autotuner(False, False, False, False) else: natten.use_autotuner(True, True) + elif flex: + natten.use_fused_na(True, use_flex_attention=flex) func = partial(profile_na_with_torch, fuse=fuse) if fmha: diff --git a/tools/profile_2d.py b/tools/profile_2d.py index 16686f0..2b1f072 100644 --- a/tools/profile_2d.py +++ b/tools/profile_2d.py @@ -57,6 +57,7 @@ @click.option("--fmha", is_flag=True) @click.option("--fav2", is_flag=True) @click.option("--backprop", is_flag=True) +@click.option("--flex", is_flag=True) @click.option("--add-kv", default=0) def profile_2d( batch_size: int, @@ -80,6 +81,7 @@ def profile_2d( fav2: bool, backprop: bool, add_kv: int, + flex: bool, ): dtype = torch.float32 @@ -96,7 +98,7 @@ def profile_2d( natten.libnatten.set_gemm_tf32(False) if fuse: - natten.use_fused_na() + natten.use_fused_na(True, use_flex_attention=flex) natten.use_kv_parallelism_in_fused_na() natten.set_memory_usage_preference("unrestricted") @@ -104,6 +106,8 @@ def profile_2d( natten.use_autotuner(False, False, False, False) else: natten.use_autotuner(True, True) + elif flex: + natten.use_fused_na(True, use_flex_attention=flex) func = partial(profile_na_with_torch, fuse=fuse) if fmha: diff --git a/tools/profile_3d.py b/tools/profile_3d.py index 05669d9..74d7efe 100644 --- a/tools/profile_3d.py +++ b/tools/profile_3d.py @@ -55,6 +55,7 @@ @click.option("--fmha", is_flag=True) @click.option("--fav2", is_flag=True) @click.option("--backprop", is_flag=True) +@click.option("--flex", is_flag=True) @click.option("--add-kv", default=0) def profile_3d( batch_size: int, @@ -76,6 +77,7 @@ def profile_3d( fav2: bool, backprop: bool, add_kv: int, + flex: bool, ): dtype = torch.float32 @@ -85,7 +87,7 @@ def profile_3d( dtype = torch.bfloat16 if fuse: - natten.use_fused_na() + natten.use_fused_na(True, use_flex_attention=flex) natten.use_kv_parallelism_in_fused_na() natten.set_memory_usage_preference("unrestricted") @@ -93,6 +95,8 @@ def profile_3d( natten.use_autotuner(False, False, False, False) else: natten.use_autotuner(True, True) + elif flex: + natten.use_fused_na(True, use_flex_attention=flex) func = partial(profile_na_with_torch, fuse=fuse) if fmha: diff --git a/tools/utils/formatting.py 
b/tools/utils/formatting.py index a79732b..c949995 100644 --- a/tools/utils/formatting.py +++ b/tools/utils/formatting.py @@ -302,14 +302,14 @@ def extract_na_ops( tags[op] = tag else: assert tags[op] == tag - logged_ops[op].append(evt.cuda_time_total) - elif evt.cuda_time_total > 0: + logged_ops[op].append(evt.device_time_total) + elif evt.device_time_total > 0: op_namespace, op_name = custom_op_to_name(evt.key) op_key = CustomOp(op_name, op_namespace) if op_key not in logged_ops: - logged_ops[op_key] = [evt.cuda_time_total] + logged_ops[op_key] = [evt.device_time_total] else: - logged_ops[op_key].append(evt.cuda_time_total) + logged_ops[op_key].append(evt.device_time_total) converted_ops = convert_ops(logged_ops, tags) return None if converted_ops is None else sorted(converted_ops)
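As a closing note on the new src/natten/flex.py: its neighborhood mask is defined purely as a predicate over flattened query/key indices, so it can be sanity-checked outside FlexAttention by materializing the same predicate as a dense boolean matrix. The sketch below mirrors the non-causal 1-D branch of na_mask_mod from this diff (window-center clamping near the boundaries plus the dilation-group check); the helper name and problem sizes are illustrative assumptions, not part of the patch.

```python
# Dense re-statement of the 1-D, non-causal branch of na_mask_mod from
# src/natten/flex.py in this diff -- a sanity check, not library code.
import torch


def dense_na1d_mask(length: int, kernel_size: int, dilation: int) -> torch.Tensor:
    q_idx = torch.arange(length).view(-1, 1)   # query positions, column vector
    kv_idx = torch.arange(length).view(1, -1)  # key/value positions, row vector

    kernel_times_dilation = kernel_size * dilation
    # Clamp the window center so queries near the edges still see a full window.
    center = q_idx.clamp(
        (kernel_times_dilation - 1) // 2,
        (length - 1) - (kernel_times_dilation - 1) // 2,
    )
    within_window = (center - kv_idx).abs() <= kernel_times_dilation // 2
    same_dilation_group = (q_idx % dilation) == (kv_idx % dilation)
    return within_window & same_dilation_group


mask = dense_na1d_mask(length=10, kernel_size=3, dilation=1)
# With the clamped window, every query attends to exactly kernel_size keys.
assert (mask.sum(dim=-1) == 3).all()
```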