From e3d2efd758efd03bf555c98313e68c77e6542dcc Mon Sep 17 00:00:00 2001 From: Rachit Garg Date: Wed, 13 Mar 2024 05:18:24 -0700 Subject: [PATCH 01/16] add external margin (#713) Add envvar for SM margin in GEMM Signed-off-by: Rachit Garg Co-authored-by: Rachit Garg --- transformer_engine/pytorch/csrc/comm_gemm_overlap.h | 3 +++ transformer_engine/pytorch/csrc/ts_fp8_op.cpp | 13 +++++++++++-- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index 827dec5010..5f8ccab334 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -18,6 +18,7 @@ #include #include "common/util/logging.h" +#include "common/util/system.h" #include "userbuffers/userbuffers.h" #define HALF_BYTES 2 @@ -112,6 +113,7 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { cudaDeviceProp prop; cudaGetDeviceProperties(&prop, 0); _math_sms = (set_sm_margin) ? prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount; + _math_sms -= transformer_engine::getenv("NVTE_EXT_MARGIN_SM", 0); output_tensor = torch::Tensor(); auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); @@ -587,6 +589,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { cudaDeviceProp prop; cudaGetDeviceProperties(&prop, 0); _math_sms = (set_sm_margin) ? prop.multiProcessorCount - num_comm_sm : prop.multiProcessorCount; + _math_sms -= transformer_engine::getenv("NVTE_EXT_MARGIN_SM", 0); _tp_size = tp_size; _aggregate2 = aggregate2; diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp index b25a8cf110..71402d2001 100755 --- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp +++ b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp @@ -6,7 +6,9 @@ #include #include "extensions.h" - +#include +#include +#include "common/util/system.h" namespace { transformer_engine::DType reverse_map_dtype(int64_t dtype) { @@ -316,6 +318,13 @@ at::Tensor te_gemm_ts(at::Tensor A, bool accumulate_arg = static_cast(accumulate); bool use_split_accumulator_arg = static_cast(use_split_accumulator); + // Set an external SM Margin to all the GEMMs. + // This comes in handy when DP is overlapped with GEMMs + cudaDeviceProp prop; + cudaGetDeviceProperties(&prop, 0); + int num_math_sms = prop.multiProcessorCount \ + - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", 0); + if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; @@ -342,7 +351,7 @@ at::Tensor te_gemm_ts(at::Tensor A, workspaceSize_arg, accumulate_arg, use_split_accumulator_arg, - 0); + num_math_sms); return D; } From 2d0ab27f6aa2982e35ad371c76aa69ec437c0a49 Mon Sep 17 00:00:00 2001 From: Santosh Bhavani Date: Wed, 13 Mar 2024 16:19:09 -0500 Subject: [PATCH 02/16] Update README - Latest News (#718) Update README.rst - Latest News Added an entry to Latest News section Signed-off-by: Santosh Bhavani --- README.rst | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/README.rst b/README.rst index 8a4443f699..de3a331d10 100644 --- a/README.rst +++ b/README.rst @@ -13,7 +13,7 @@ Transformer Engine Latest News ================== - +* [03/2024] `FP8 Training Support in SageMaker Model Parallelism Library `_ * [12/2023] `New NVIDIA NeMo Framework Features and NVIDIA H200 `_ .. 
image:: docs/examples/H200-NeMo-performance.png @@ -226,7 +226,7 @@ Transformer Engine has been integrated with popular LLM frameworks such as: * `NVIDIA JAX Toolbox `_ * `NVIDIA Megatron-LM `_ * `NVIDIA NeMo Framework `_ -* `Amazon SageMaker Model Parallel Library ` +* `Amazon SageMaker Model Parallel Library `_ * `Colossal-AI `_ - Coming soon! * `PeriFlow `_ - Coming soon! * `GPT-NeoX `_ - Coming soon! From ffa24475a97f9223659effbcf4ccda6d1adb9a18 Mon Sep 17 00:00:00 2001 From: Keshav Balasubramanian Date: Thu, 14 Mar 2024 13:55:34 -0700 Subject: [PATCH 03/16] Ln force no weight sharding (#715) * disallow sharding of layernorm learnable parameters; force duplication Signed-off-by: Keshav * fix tests and support tensors for gamma/beta in layernorms Signed-off-by: Keshav * reverting Signed-off-by: Keshav * added tests for rank-1 gamma/beta sharding Signed-off-by: Keshav * fix lint errors Signed-off-by: Keshav --------- Signed-off-by: Keshav --- tests/jax/distributed_test_base.py | 2 +- tests/jax/test_distributed_layernorm.py | 77 +++++++++++++++------ transformer_engine/jax/cpp_extensions.py | 86 ++++++++++++++++++++---- 3 files changed, 129 insertions(+), 36 deletions(-) diff --git a/tests/jax/distributed_test_base.py b/tests/jax/distributed_test_base.py index bf77622d50..1360af19fa 100644 --- a/tests/jax/distributed_test_base.py +++ b/tests/jax/distributed_test_base.py @@ -18,7 +18,7 @@ def generate_configs(): if is_devices_enough(2): configs.append([2, (2,), ('dp'), MeshResource(dp_resource='dp')]) configs.append([2, (2,), ('tp'), MeshResource(tp_resource='tp')]) - + if is_devices_enough(4): TP_size = 2 DP_size = 2 diff --git a/tests/jax/test_distributed_layernorm.py b/tests/jax/test_distributed_layernorm.py index a72639e25b..3aa5b9ae20 100644 --- a/tests/jax/test_distributed_layernorm.py +++ b/tests/jax/test_distributed_layernorm.py @@ -2,6 +2,7 @@ # # See LICENSE for license information. 
+import warnings import pytest import jax @@ -20,7 +21,7 @@ class TestDistributedLayernorm: - def generate_inputs(self, shape, mesh_resource, dtype): + def generate_inputs(self, shape, mesh_resource, dtype, shard_weights): weight_shape = (shape[-1],) x = random.normal(random.PRNGKey(1124), shape, dtype=dtype) @@ -34,7 +35,7 @@ def generate_inputs(self, shape, mesh_resource, dtype): else: raise NotImplementedError - g_pspec = b_pspec = PartitionSpec(None) + g_pspec = b_pspec = PartitionSpec(mesh_resource.dp_resource) if shard_weights else PartitionSpec(None) return (x, gamma, beta), (x_pspec, g_pspec, b_pspec) @@ -54,8 +55,9 @@ def generate_collectives_count_ref(self, mesh_resource, ln_type, shape, dtype): @pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]]) @pytest.mark.parametrize('dtype', DTYPES) @pytest.mark.parametrize('zero_centered_gamma', [False, True]) + @pytest.mark.parametrize('shard_weights', [False, True]) def test_layernorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, - zero_centered_gamma): + zero_centered_gamma, shard_weights): epsilon = 1e-6 ln_type = 'layernorm' @@ -74,7 +76,7 @@ def ref_func(x, gamma, beta): return jnp.mean(output) (x, gamma, beta), (x_pspec, g_pspec, b_pspec) = \ - self.generate_inputs(data_shape, mesh_resource, dtype) + self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights) collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type, data_shape, dtype) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) @@ -84,19 +86,35 @@ def ref_func(x, gamma, beta): gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec)) beta_ = jax.device_put(beta, NamedSharding(mesh, b_pspec)) - compare_ops(target_func, - ref_func, [x_, gamma_, beta_], - collective_count_ref, - grad_args=(0, 1, 2), - metric_fwd_dtype=dtype, - metric_bwd_dtype=dtype, - in_shardings=(x_pspec, g_pspec, b_pspec), - out_shardings=(None, (x_pspec, g_pspec, b_pspec))) + with warnings.catch_warnings(record=True) as warns: + try: + compare_ops(target_func, + ref_func, [x_, gamma_, beta_], + collective_count_ref, + grad_args=(0, 1, 2), + metric_fwd_dtype=dtype, + metric_bwd_dtype=dtype, + in_shardings=(x_pspec, g_pspec, b_pspec), + out_shardings=(None, (x_pspec, g_pspec, b_pspec))) + except AssertionError as err: + # Layernorm should still produce the correct numerical result with + # gamma/beta sharded. However, the collective count may not be the same + # when XLA is forced to unshard gamma and/or beta. We can catch + # and ignore that specific error here. + if (g_pspec[-1] is None and b_pspec[-1] is None) or "Expected collective count" not in str(err): + raise err + finally: + for w in warns: + assert "Enforcing no sharding of parameters hidden dim!" 
in str(w), ( + "Layernorm primitive did not raise the correct warning for " + "unsupported sharding of gamma and/or beta" + ) @pytest.mark.parametrize('device_count,mesh_shape,mesh_axes,mesh_resource', generate_configs()) @pytest.mark.parametrize('data_shape', [[32, 128, 1024], [32, 1024]]) @pytest.mark.parametrize('dtype', DTYPES) - def test_rmsnorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype): + @pytest.mark.parametrize('shard_weights', [False, True]) + def test_rmsnorm(self, device_count, mesh_shape, mesh_axes, mesh_resource, data_shape, dtype, shard_weights): epsilon = 1e-6 ln_type = 'rmsnorm' @@ -111,7 +129,7 @@ def ref_func(x, gamma): return jnp.mean(output) (x, gamma, _), (x_pspec, g_pspec, _) = \ - self.generate_inputs(data_shape, mesh_resource, dtype) + self.generate_inputs(data_shape, mesh_resource, dtype, shard_weights) collective_count_ref = self.generate_collectives_count_ref(mesh_resource, ln_type, data_shape, dtype) devices = np.asarray(jax.devices()[:device_count]).reshape(*mesh_shape) @@ -120,11 +138,26 @@ def ref_func(x, gamma): x_ = jax.device_put(x, NamedSharding(mesh, x_pspec)) gamma_ = jax.device_put(gamma, NamedSharding(mesh, g_pspec)) - compare_ops(target_func, - ref_func, [x_, gamma_], - collective_count_ref, - grad_args=(0, 1), - metric_fwd_dtype=dtype, - metric_bwd_dtype=dtype, - in_shardings=(x_pspec, g_pspec), - out_shardings=(None, (x_pspec, g_pspec))) + with warnings.catch_warnings(record=True) as warns: + try: + compare_ops(target_func, + ref_func, [x_, gamma_], + collective_count_ref, + grad_args=(0, 1), + metric_fwd_dtype=dtype, + metric_bwd_dtype=dtype, + in_shardings=(x_pspec, g_pspec), + out_shardings=(None, (x_pspec, g_pspec))) + except AssertionError as err: + # RmsNorm should still produce the correct numerical result with + # gamma/beta sharded. However, the collective count may not be the same + # when XLA is forced to unshard gamma. We can catch + # and ignore that specific error here. + if g_pspec[-1] is None or "Expected collective count" not in str(err): + raise err + finally: + for w in warns: + assert "Enforcing no sharding of parameters hidden dim!" in str(w), ( + "RmsNorm primitive did not raise the correct warning for " + "unsupported sharding of gamma and/or beta" + ) diff --git a/transformer_engine/jax/cpp_extensions.py b/transformer_engine/jax/cpp_extensions.py index d63f4ceeca..ba1d44318c 100644 --- a/transformer_engine/jax/cpp_extensions.py +++ b/transformer_engine/jax/cpp_extensions.py @@ -453,9 +453,21 @@ def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos): f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ f"and hurt performance." ) + if g_spec[-1] is not None: + warnings.warn( + f"{LayerNormFwdPrimitive.name} does not support sharding of parameter gamma " \ + f"Enforcing no sharding of parameters hidden dim! " \ + ) + if b_spec[-1] is not None: + warnings.warn( + f"{LayerNormFwdPrimitive.name} does not support sharding of parameter beta " \ + f"Enforcing no sharding of parameters hidden dim! 
" \ + ) + + x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) - g_sharding = NamedSharding(mesh, PartitionSpec(*g_spec)) - b_sharding = NamedSharding(mesh, PartitionSpec(*b_spec)) + g_sharding = NamedSharding(mesh, PartitionSpec(None)) + b_sharding = NamedSharding(mesh, PartitionSpec(None)) out_sharding = x_sharding mu_sharding = rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1])) @@ -628,8 +640,15 @@ def infer_sharding_from_operands(zero_centered_gamma, epsilon, mesh, arg_infos, f"and hurt performance." ) g_b_spec = get_padded_spec(arg_infos[4]) + if g_b_spec[-1] is not None: + warnings.warn( + f"{LayerNormBwdPrimitive.name} does not support sharding of gradients " \ + f"of gamma and beta of Layernorm " \ + f"Enforcing no sharding of parameters hidden dim! " \ + ) + dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) - dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(*g_b_spec)) + dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(None)) return dx_sharding, dgamma_sharding, dbeta_sharding @staticmethod @@ -643,12 +662,19 @@ def partition(zero_centered_gamma, epsilon, mesh, arg_infos, result_infos): f"and hurt performance." ) g_b_spec = get_padded_spec(arg_infos[4]) + if g_b_spec[-1] is not None: + warnings.warn( + f"{LayerNormBwdPrimitive.name} does not support sharding of gradients " \ + f"of gamma and beta of Layernorm " \ + f"Enforcing no sharding of parameters hidden dim! " \ + ) + dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) - dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(*g_b_spec)) + dgamma_sharding = dbeta_sharding = NamedSharding(mesh, PartitionSpec(None)) out_shardings = dx_sharding, dgamma_sharding, dbeta_sharding x_shardings = (dx_sharding,) * 2 # dz and x should have the same sharding. mu_shardings = (NamedSharding(mesh, PartitionSpec(*x_spec[:-1])),) * 2 - arg_shardings = (*x_shardings, *mu_shardings, NamedSharding(mesh, PartitionSpec(*g_b_spec))) + arg_shardings = (*x_shardings, *mu_shardings, NamedSharding(mesh, PartitionSpec(None))) def sharded_impl(dz, x, mu, rsigma, gamma): local_dx, local_dgamma, local_dbeta = \ @@ -828,8 +854,14 @@ def partition(epsilon, mesh, arg_infos, result_infos): f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ f"and hurt performance." ) + if g_spec[-1] is not None: + warnings.warn( + f"{RmsNormFwdPrimitive.name} does not support sharding of parameter gamma " \ + f"Enforcing no sharding of parameters hidden dim! " \ + ) + x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) - g_sharding = NamedSharding(mesh, PartitionSpec(*g_spec)) + g_sharding = NamedSharding(mesh, PartitionSpec(None)) out_sharding = x_sharding rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1])) arg_shardings = (x_sharding, g_sharding) @@ -982,8 +1014,13 @@ def infer_sharding_from_operands(epsilon, mesh, arg_infos, result_infos): f"and hurt performance." ) g_spec = get_padded_spec(arg_infos[3]) + if g_spec[-1] is not None: + warnings.warn( + f"{RmsNormBwdPrimitive.name} does not support sharding of parameter gamma " \ + f"Enforcing no sharding of parameters hidden dim! 
" \ + ) dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) - dgamma_sharding = NamedSharding(mesh, PartitionSpec(*g_spec)) + dgamma_sharding = NamedSharding(mesh, PartitionSpec(None)) return dx_sharding, dgamma_sharding @staticmethod @@ -997,12 +1034,17 @@ def partition(epsilon, mesh, arg_infos, result_infos): f"and hurt performance." ) g_spec = get_padded_spec(arg_infos[3]) + if g_spec[-1] is not None: + warnings.warn( + f"{RmsNormBwdPrimitive.name} does not support sharding of parameter gamma " \ + f"Enforcing no sharding of parameters hidden dim! " \ + ) dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) - dgamma_sharding = NamedSharding(mesh, PartitionSpec(*g_spec)) + dgamma_sharding = NamedSharding(mesh, PartitionSpec(None)) out_shardings = dx_sharding, dgamma_sharding x_shardings = (dx_sharding,) * 2 # dz and x should have the same sharding. rsigma_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1])) - arg_shardings = (*x_shardings, rsigma_sharding, NamedSharding(mesh, PartitionSpec(*g_spec))) + arg_shardings = (*x_shardings, rsigma_sharding, NamedSharding(mesh, PartitionSpec(None))) def sharded_impl(dz, x, rsigma, gamma): local_dx, local_dgamma = \ @@ -4336,15 +4378,27 @@ def infer_sharding_from_operands(out_dtype, zero_centered_gamma, epsilon, mesh, def partition(out_dtype, zero_centered_gamma, epsilon, mesh, arg_infos, result_infos): del result_infos x_spec = get_padded_spec(arg_infos[0]) + g_spec = get_padded_spec(arg_infos[1]) + b_spec = get_padded_spec(arg_infos[2]) if x_spec[-1] is not None: warnings.warn( - f"Does not support to shard hidden dim in {LayerNormFwdPrimitive.name}! " \ + f"Does not support to shard hidden dim in {LayerNormFwdFp8Primitive.name}! " \ f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ f"and hurt performance." ) + if g_spec[-1] is not None: + warnings.warn( + f"{LayerNormFwdFp8Primitive.name} does not support sharding of parameter gamma " \ + f"Enforcing no sharding of parameters hidden dim! " \ + ) + if b_spec[-1] is not None: + warnings.warn( + f"{LayerNormFwdFp8Primitive.name} does not support sharding of parameter beta " \ + f"Enforcing no sharding of parameters hidden dim! " \ + ) x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) - g_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) - b_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) + g_sharding = NamedSharding(mesh, PartitionSpec(None)) + b_sharding = NamedSharding(mesh, PartitionSpec(None)) out_sharding = x_sharding mu_sharding = rsigma_sharding = NamedSharding( mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1])) @@ -4568,14 +4622,20 @@ def infer_sharding_from_operands(out_dtype, epsilon, mesh, arg_infos, result_inf def partition(out_dtype, epsilon, mesh, arg_infos, result_infos): del result_infos x_spec = get_padded_spec(arg_infos[0]) + g_spec = get_padded_spec(arg_infos[1]) if x_spec[-1] is not None: warnings.warn( f"Does not support to shard hidden dim in {RmsNormFwdFp8Primitive.name}! " \ f"Force to not shard the hidden dim, which might introduce extra collective ops, " \ f"and hurt performance." ) + if g_spec[-1] is not None: + warnings.warn( + f"{RmsNormFwdFp8Primitive.name} does not support sharding of parameter gamma " \ + f"Enforcing no sharding of parameters hidden dim! 
" \ + ) x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-1], None)) - g_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[1]))) + g_sharding = NamedSharding(mesh, PartitionSpec(None)) out_sharding = x_sharding rsigma_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[0])[:-1])) amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) From 1ec33ae1191ae6644365155f8e8f618145c44cd7 Mon Sep 17 00:00:00 2001 From: Rachit Garg Date: Fri, 15 Mar 2024 13:44:15 -0700 Subject: [PATCH 04/16] Rachitg/dp carveout (#722) * fix the perf regression because of constant property polling of the device Signed-off-by: Rachit Garg * Fix lint error Signed-off-by: Tim Moon --------- Signed-off-by: Rachit Garg Signed-off-by: Tim Moon Co-authored-by: Rachit Garg Co-authored-by: Tim Moon --- transformer_engine/pytorch/csrc/ts_fp8_op.cpp | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp index 71402d2001..a7217d4570 100755 --- a/transformer_engine/pytorch/csrc/ts_fp8_op.cpp +++ b/transformer_engine/pytorch/csrc/ts_fp8_op.cpp @@ -9,6 +9,7 @@ #include #include #include "common/util/system.h" +#include "common/util/cuda_runtime.h" namespace { transformer_engine::DType reverse_map_dtype(int64_t dtype) { @@ -320,10 +321,9 @@ at::Tensor te_gemm_ts(at::Tensor A, // Set an external SM Margin to all the GEMMs. // This comes in handy when DP is overlapped with GEMMs - cudaDeviceProp prop; - cudaGetDeviceProperties(&prop, 0); - int num_math_sms = prop.multiProcessorCount \ - - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", 0); + + const int sm_count = transformer_engine::cuda::sm_count(); + int num_math_sms = sm_count - transformer_engine::getenv("NVTE_EXT_MARGIN_SM", sm_count); if (A_scale_inverse.numel()) A_scale_inverse = A_scale_inverse[A_fp8_tensor]; From a3ba77b8941fd02aee9e1b2a811558fd7e341ffe Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Mon, 18 Mar 2024 12:58:44 +0000 Subject: [PATCH 05/16] Changed VERSION to 1.6.0dev Signed-off-by: Kirthi Shankar Sivamani --- VERSION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/VERSION b/VERSION index 7cea7589b5..65babdef47 100644 --- a/VERSION +++ b/VERSION @@ -1 +1 @@ -1.5.0.dev0 +1.6.0.dev0 From 965803c9dcf889f7e9820921cb811c2bf9b91dbc Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Wed, 20 Mar 2024 12:54:13 -0700 Subject: [PATCH 06/16] Update FA version to 2.5.6 (#714) Signed-off-by: Kirthi Shankar Sivamani --- setup.py | 2 +- transformer_engine/pytorch/attention.py | 4 ++++ 2 files changed, 5 insertions(+), 1 deletion(-) diff --git a/setup.py b/setup.py index d50a9b8706..e6bfe496ff 100644 --- a/setup.py +++ b/setup.py @@ -265,7 +265,7 @@ def add_unique(l: List[str], vals: Union[str, List[str]]) -> None: # Framework-specific requirements if "pytorch" in frameworks(): - add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"]) + add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.5.6,!=2.0.9,!=2.1.0"]) add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"]) if "jax" in frameworks(): if not found_pybind11(): diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 321b1bfac8..2f9ee988e0 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -58,6 +58,7 @@ _flash_attn_version = 
packaging.version.Version(version("flash-attn")) _flash_attn_version_required = packaging.version.Version("2.0.6") +_flash_attn_max_version = packaging.version.Version("2.5.6") _flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1") _flash_attn_2_3_plus = _flash_attn_version >= packaging.version.Version("2.3") _flash_attn_2_4_plus = _flash_attn_version >= packaging.version.Version("2.4") @@ -1656,6 +1657,9 @@ def __init__( assert ( _flash_attn_version >= _flash_attn_version_required ), f"FlashAttention minimum version {_flash_attn_version_required} is required." + assert ( + _flash_attn_version <= _flash_attn_max_version + ), f"FlashAttention maximum version {_flash_attn_max_version} is supported." self.norm_factor = norm_factor self.attention_dropout_ctx = attention_dropout_ctx From c38779bec64cdbb542b4f92cf7942e7d76d60c54 Mon Sep 17 00:00:00 2001 From: Sudhakar Singh Date: Wed, 20 Mar 2024 12:54:29 -0700 Subject: [PATCH 07/16] Llama accelerate tutorial (#720) * tutorial and doc fixes Signed-off-by: Sudhakar Singh * remove extra code Signed-off-by: Sudhakar Singh * fix typos Signed-off-by: Sudhakar Singh --------- Signed-off-by: Sudhakar Singh --- docs/examples/te_llama/te_llama.py | 7 +- ...tutorial_accelerate_hf_llama_with_te.ipynb | 74 ++++++++++--------- docs/examples/te_llama/utils.py | 33 ++++++--- 3 files changed, 65 insertions(+), 49 deletions(-) diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py index fba35ed30c..c73bed45b4 100644 --- a/docs/examples/te_llama/te_llama.py +++ b/docs/examples/te_llama/te_llama.py @@ -21,12 +21,12 @@ from transformers.utils.hub import get_checkpoint_shard_files @contextmanager -def replace_decoder(te_decodder_cls): +def replace_decoder(te_decoder_cls): """ Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`. 
""" original_llama_decoder_cls = transformers.models.llama.modeling_llama.LlamaDecoderLayer - transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decodder_cls + transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls try: yield finally: @@ -56,6 +56,7 @@ def __init__(self, config, *args, **kwargs): normalization="RMSNorm", activation="swiglu", attn_input_format="bshd", + num_gqa_groups=config.num_key_value_heads, ) te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads) self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda() @@ -84,7 +85,7 @@ class is monkey-patched with `TELlamaDecoderLayer` class before """ def __new__(cls, config: LlamaConfig): - with replace_decoder(te_decodder_cls=TELlamaDecoderLayer): + with replace_decoder(te_decoder_cls=TELlamaDecoderLayer): llama_for_causal_lm = LlamaForCausalLM(config) return llama_for_causal_lm diff --git a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb index 974077de57..178922c9d2 100644 --- a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb +++ b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb @@ -2,7 +2,7 @@ "cells": [ { "cell_type": "markdown", - "id": "1f37565e", + "id": "2cac9d39", "metadata": {}, "source": [ "# Accelerating a Hugging Face Llama 2 model with Transformer Engine\n", @@ -11,14 +11,14 @@ "\n", "Goal\n", "\n", - "This tutorial showcases how accelerate finetuning a full Llama 2 model from [Hugging Face](https://huggingface.co/meta-llama/Llama-2-7b-hf) by using `TransformerLayer` from the [Transformer Engine library](https://github.com/NVIDIA/TransformerEngine) in `BF16` and `FP8` precisions.\n", + "This tutorial showcases how to accelerate finetuning a full Llama 2 model from [Hugging Face](https://huggingface.co/meta-llama/Llama-2-7b-hf) by using `TransformerLayer` from the [Transformer Engine library](https://github.com/NVIDIA/TransformerEngine) in `BF16` and `FP8` precisions.\n", "\n", "\n" ] }, { "cell_type": "markdown", - "id": "ab4c0b82", + "id": "401f7fb1", "metadata": {}, "source": [ "## Dependencies for this tutorial\n", @@ -35,7 +35,7 @@ }, { "cell_type": "markdown", - "id": "466ff515", + "id": "33bdb5fe", "metadata": {}, "source": [ "## Table of contents\n", @@ -53,7 +53,7 @@ }, { "cell_type": "markdown", - "id": "8e84bcaa", + "id": "7645f176", "metadata": {}, "source": [ "## From \"Transformer\" to \"Llama\" \n", @@ -89,7 +89,7 @@ }, { "cell_type": "markdown", - "id": "e31303c7", + "id": "d0cfa787", "metadata": {}, "source": [ "## Hugging Face's `LlamaModel`\n", @@ -166,7 +166,7 @@ }, { "cell_type": "markdown", - "id": "686df4ef", + "id": "f4f21369", "metadata": {}, "source": [ "## [Baseline] Running HF `LlamaModel` (Precision: `BF16`)\n", @@ -190,7 +190,7 @@ }, { "cell_type": "markdown", - "id": "107a8146", + "id": "24a8d0a5", "metadata": {}, "source": [ "
\n", @@ -206,8 +206,8 @@ }, { "cell_type": "code", - "execution_count": null, - "id": "975f9184", + "execution_count": 1, + "id": "e36ff380", "metadata": {}, "outputs": [ { @@ -215,7 +215,7 @@ "output_type": "stream", "text": [ "10 finetuning steps complete!\n", - "Average time taken per step: 289 milliseconds\n" + "Average time taken per step: 315 milliseconds\n" ] } ], @@ -247,19 +247,19 @@ }, { "cell_type": "markdown", - "id": "c2d5b174", + "id": "a64f0f33", "metadata": {}, "source": [ "Let's add this information in a table and keep comparing it with a few possible improvements in future sections:\n", "\n", "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", - "| HF (baseline) | BF16 | 289 | 1 |" + "| HF (baseline) | BF16 | 315 | 1 |" ] }, { "cell_type": "markdown", - "id": "a7d436bf", + "id": "d9898383", "metadata": {}, "source": [ "## [Improvement 1] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `BF16`)\n", @@ -322,6 +322,7 @@ " normalization=\"RMSNorm\",\n", " activation=\"swiglu\",\n", " attn_input_format=\"bshd\",\n", + " num_gqa_groups=config.num_key_value_heads,\n", " )\n", " te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads)\n", " self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda()\n", @@ -339,10 +340,11 @@ "8. `fuse_qkv_params`: if set to True, TransformerLayer module exposes a single fused parameter for query-key-value. This enables optimizations such as QKV fusion without concatentations/splits and also enables the argument fuse_wgrad_accumulation.\n", "9. `normalization`: type of normalization applied. Default is `LayerNorm`.\n", "10. `activation`: type of activation used in the MLP block. Default is `gelu`.\n", - "11. `attn_input_format`: controls whether the dimensions of the intermediate hidden states is 'batch first' ('bshd') or 'sequence first' ('sbhd'). `s` stands for the sequence length, `b` batch size, `h` the number of heads, `d` head size. Note that these formats are very closely related to the `qkv_format` in the `MultiHeadAttention` and `DotProductAttention` modules. \n", + "11. `attn_input_format`: controls whether the dimensions of the intermediate hidden states is 'batch first' ('bshd') or 'sequence first' ('sbhd'). `s` stands for the sequence length, `b` batch size, `h` the number of heads, `d` head size. Note that these formats are very closely related to the `qkv_format` in the `MultiHeadAttention` and `DotProductAttention` modules.\n", + "12. `num_gqa_groups`: number of GQA groups in the transformer layer. Grouped Query Attention is described in [this paper](https://arxiv.org/pdf/2305.13245.pdf). This only affects the keys and values, not the querys. GQA-1 is equivalent to Multi-Query Attention ([MQA](https://arxiv.org/pdf/1911.02150.pdf)), while GQA-H is equivalent to MultiHead Attention, i.e. `num_gqa_groups = num_attention_heads`.\n", "\n", "\n", - "Further, note that `RotaryPositionEmbedding` is defined as part of the TE's `TransformerLayer` itself since it expects this rope cache if RoPE is used in the model. \n", + "Further, note that `RotaryPositionEmbedding` is defined as part of the `TELlamaDecoderLayer` (wrapper around TE's `TransformerLayer`) itself since it expects this rope cache if RoPE is used in the model. 
\n", "\n", "Let's revisit how `LlamaDecoderLayer`s form the core of the decoder layer stack in HF's llama implementation:\n", "```\n", @@ -422,12 +424,12 @@ "\n", "```\n", "@contextmanager\n", - "def replace_decoder(te_decodder_cls):\n", + "def replace_decoder(te_decoder_cls):\n", " \"\"\"\n", " Replace `LlamaDecoderLayer` with custom `TELlamaDecoderLayer`.\n", " \"\"\"\n", " original_llama_decoder_cls = transformers.models.llama.modeling_llama.LlamaDecoderLayer\n", - " transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decodder_cls\n", + " transformers.models.llama.modeling_llama.LlamaDecoderLayer = te_decoder_cls\n", " try:\n", " yield\n", " finally:\n", @@ -446,7 +448,7 @@ " \"\"\"\n", "\n", " def __new__(cls, config: LlamaConfig):\n", - " with replace_decoder(te_decodder_cls=TELlamaDecoderLayer):\n", + " with replace_decoder(te_decoder_cls=TELlamaDecoderLayer):\n", " llama_for_causal_lm = LlamaForCausalLM(config)\n", " return llama_for_causal_lm\n", ".\n", @@ -530,7 +532,7 @@ { "cell_type": "code", "execution_count": 1, - "id": "48dc8935", + "id": "4974b738", "metadata": {}, "outputs": [ { @@ -538,7 +540,7 @@ "output_type": "stream", "text": [ "10 finetuning steps complete!\n", - "Average time taken per step: 242 milliseconds\n" + "Average time taken per step: 252 milliseconds\n" ] } ], @@ -570,20 +572,20 @@ }, { "cell_type": "markdown", - "id": "3c3d228a", + "id": "85c78c7f", "metadata": {}, "source": [ - "Compared to the \"baseline\" implementation, we see that using Transformer Engine's `TransformerLayer` in place of Huggging Face's `LlamaDecoderLayer` gives a speedup of **19%** even when using only BF16 precision!\n", + "Compared to the \"baseline\" implementation, we see that using Transformer Engine's `TransformerLayer` in place of Huggging Face's `LlamaDecoderLayer` gives a speedup of **25%** even when using only BF16 precision!\n", "\n", "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", - "| HF (baseline) | BF16 | 289 | 1 |\n", - "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 242 | 1.19 |" + "| HF (baseline) | BF16 | 315 | 1 |\n", + "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 252 | 1.25 |" ] }, { "cell_type": "markdown", - "id": "b92d6792", + "id": "e2fb88e9", "metadata": {}, "source": [ "## [Improvement 2] Replace HF's `LlamaDecoderLayer` with TE's `TransformerLayer` (Precision: `FP8`)\n", @@ -608,8 +610,8 @@ }, { "cell_type": "code", - "execution_count": 2, - "id": "6bba7cc1", + "execution_count": 1, + "id": "8f2b752e", "metadata": {}, "outputs": [ { @@ -617,7 +619,7 @@ "output_type": "stream", "text": [ "10 finetuning steps complete!\n", - "Average time taken per step: 231 milliseconds\n" + "Average time taken per step: 226 milliseconds\n" ] } ], @@ -649,27 +651,27 @@ }, { "cell_type": "markdown", - "id": "602239d7", + "id": "67ec126c", "metadata": {}, "source": [ "| Models | Precision | Step Time (or ms per batch) | Speedup (over baseline) |\n", "|-------------------------------------------------------------|-----------|-----------------------------|-------------------------|\n", - "| HF (baseline) | BF16 | 289 | 1 |\n", - "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 242 | 1.19 |\n", - "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | FP8 | 231 | 1.25 |\n", + "| HF (baseline) | BF16 | 315 | 1 |\n", + 
"| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | BF16 | 252 | 1.25 |\n", + "| TE (replace `LlamaDecoderLayer` with `TE.TransformerLayer`) | FP8 | 226 | 1.39 |\n", "\n", "\n", - "After turning on FP8 precision, we get even more speedup of **25%**!" + "After turning on FP8 precision, we get even more speedup of almost **40%**!" ] }, { "cell_type": "markdown", - "id": "372867d5", + "id": "41b80b0f", "metadata": {}, "source": [ "## Conclusion\n", "\n", - "Using `TransformerLayer` module from Transformer Engine as a substitute for Hugging Face's `LlamaDecoderLayer` provides speedup over Hugging Face's native Llama 2 implementation. This needs careful initializing of model such that the model weights (which are meant for `LlamaDecoderLayer`) are correctly mapped to their counterparts in TE's `TransformerLayer`. Even with `BF16` precision, `TransformerLayer` provides a speedup over the baseline implementation. With `FP8` precision, the speed up is even more pronounced!" + "Using `TransformerLayer` module from Transformer Engine as a substitute for Hugging Face's `LlamaDecoderLayer` provides a speedup over Hugging Face's native Llama 2 implementation. This needs careful initialization of the model such that the model weights (which are meant for `LlamaDecoderLayer`) are correctly mapped to their counterparts in TE's `TransformerLayer`. Even with `BF16` precision, `TransformerLayer` provides a speedup over the baseline implementation. With `FP8` precision, the speed up is even more pronounced!" ] } ], diff --git a/docs/examples/te_llama/utils.py b/docs/examples/te_llama/utils.py index 04abe39b6a..54b329f12b 100644 --- a/docs/examples/te_llama/utils.py +++ b/docs/examples/te_llama/utils.py @@ -26,7 +26,9 @@ def __init__(self): self.batch_size = 8 self.max_seq_length = 256 self.gradient_accumulation_steps = 1 + self.num_warmup_steps=5 self.num_training_steps=10 + hyperparams = HyperParameters() @@ -132,11 +134,9 @@ def finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, optimizer.zero_grad() train_dataloader = enumerate(train_dataloader) - time_vals = [] - - for _ in range(hyperparams.num_training_steps): + # Warmup iters + for _ in range(hyperparams.num_warmup_steps): step, batch = next(train_dataloader) - start_time = time.time() with accelerator.accumulate(model): outputs = model(**batch) loss = outputs.loss @@ -146,15 +146,28 @@ def finetune_model(model, hyperparams, accelerator, train_dataloader, optimizer, lr_scheduler.step() optimizer.zero_grad() - end_time = time.time() - total_time = end_time - start_time - time_vals.append(total_time) + # Get the timers ready + start = torch.cuda.Event(enable_timing=True) + end = torch.cuda.Event(enable_timing=True) + torch.cuda.synchronize() + start.record() + # Training iters + for _ in range(hyperparams.num_training_steps): + step, batch = next(train_dataloader) + with accelerator.accumulate(model): + outputs = model(**batch) + loss = outputs.loss + total_loss += loss.detach().float() + accelerator.backward(loss) + optimizer.step() + lr_scheduler.step() + optimizer.zero_grad() + torch.cuda.synchronize() + end.record() accelerator.end_training() - # ignore the first couple of time vals - time_vals = time_vals[2:] - print(f"{hyperparams.num_training_steps} finetuning steps complete!\nAverage time taken per step: {(sum(time_vals)/len(time_vals)) * 1000:.0f} milliseconds") + print(f"{hyperparams.num_training_steps} finetuning steps complete!\nAverage time taken per step: 
{(start.elapsed_time(end)/hyperparams.num_training_steps):.0f} milliseconds") def restart_jupyter_notebook(): # Try restarting the Jupyter kernel From 59bfc17b7376debe6b0b442e246b296b2f58becf Mon Sep 17 00:00:00 2001 From: Kite0011 <50414751+Kite0011@users.noreply.github.com> Date: Thu, 21 Mar 2024 14:39:38 +0800 Subject: [PATCH 08/16] [Pytorch] Update context parallel softmax lse correction func (#716) [Pytorch] Update context parallel softmax lse correction func. Signed-off-by: kitefang Co-authored-by: kitefang --- transformer_engine/pytorch/attention.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 2f9ee988e0..26ab9f3283 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -483,9 +483,10 @@ def flash_attn_fwd_out_correction(out, out_per_step, softmax_lse, softmax_lse_pe @jit_fuser def flash_attn_fwd_softmax_lse_correction(softmax_lse, softmax_lse_per_step): """Merge softmax stats of each step in Attention with context parallelism""" - softmax_lse.exp_() - softmax_lse.add_(softmax_lse_per_step.to(torch.double).exp()) - softmax_lse.log_() + max_scale = torch.max(softmax_lse, softmax_lse_per_step) + min_scale = torch.min(softmax_lse, softmax_lse_per_step) + new_scale = max_scale + torch.log(1 + torch.exp(min_scale - max_scale)) + softmax_lse.copy_(new_scale) class AttnFuncWithCP(torch.autograd.Function): From b855656b20bc1cb3df4327fc017acdbd240100c9 Mon Sep 17 00:00:00 2001 From: Sangkug Lym Date: Thu, 21 Mar 2024 12:02:53 -0700 Subject: [PATCH 09/16] TP-RS overlap with send/recv ring-exchange (#724) * TP-RS overlap with send/recv Atomic GEMM based TP-RS overlap with send/recv Signed-off-by: Sangkug Lym Specify userbuffer overlap method of each overlap instance Signed-off-by: Sangkug Lym P2P TP-RS overlap with fp8 GEMM outputs Signed-off-by: Sangkug Lym Fix TP-RS overlap with send/recv Signed-off-by: Sangkug Lym * cleanup Signed-off-by: Sangkug Lym * cleanup Signed-off-by: Sangkug Lym * linting Signed-off-by: Sangkug Lym * fix typo Signed-off-by: Sangkug Lym --------- Signed-off-by: Sangkug Lym Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> --- transformer_engine/pytorch/attention.py | 18 +- .../pytorch/cpp_extensions/gemm.py | 30 +- .../pytorch/csrc/comm_gemm_overlap.h | 267 ++++++++++++++++-- .../pytorch/csrc/extensions/pybind.cpp | 26 +- .../pytorch/csrc/userbuffers/userbuffers.cu | 31 ++ .../pytorch/csrc/userbuffers/userbuffers.h | 4 + transformer_engine/pytorch/module/base.py | 43 ++- .../pytorch/module/layernorm_linear.py | 52 ++-- .../pytorch/module/layernorm_mlp.py | 137 ++++----- transformer_engine/pytorch/module/linear.py | 122 ++++---- transformer_engine/pytorch/transformer.py | 35 +-- 11 files changed, 497 insertions(+), 268 deletions(-) diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 26ab9f3283..924e2bb97d 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -3175,10 +3175,8 @@ def __init__( qkv_weight_interleaved: bool = True, ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, - ub_split_rs: bool = False, - ub_split_ag: bool = False, - ub_atomic_gemm_rs: bool = False, - ub_atomic_gemm_ag: bool = False, + ub_overlap_rs: bool = False, + ub_overlap_ag: bool = False, bias: bool = True, normalization: str = "LayerNorm", device: Union[torch.device, str] = "cuda", @@ -3265,9 +3263,8 @@ def 
__init__( zero_centered_gamma=zero_centered_gamma, ub_bulk_wgrad=ub_bulk_wgrad, ub_bulk_dgrad=ub_bulk_dgrad, - ub_split_ag=ub_split_ag, + ub_overlap_ag=ub_overlap_ag, normalization=normalization, - ub_atomic_gemm_ag=ub_atomic_gemm_ag, ub_name="qkv", **common_gemm_kwargs, ) @@ -3297,9 +3294,8 @@ def __init__( zero_centered_gamma=zero_centered_gamma, ub_bulk_wgrad=ub_bulk_wgrad, ub_bulk_dgrad=ub_bulk_dgrad, - ub_split_ag=ub_split_ag, + ub_overlap_ag=ub_overlap_ag, normalization=normalization, - ub_atomic_gemm_ag=ub_atomic_gemm_ag, ub_name="qkv", **common_gemm_kwargs, ) @@ -3347,10 +3343,8 @@ def __init__( bias=bias, return_bias=return_bias, parallel_mode="row" if set_parallel_mode else None, - ub_split_rs=ub_split_rs, - ub_split_ag=ub_split_ag, - ub_atomic_gemm_rs=ub_atomic_gemm_rs, - ub_atomic_gemm_ag=ub_atomic_gemm_ag, + ub_overlap_rs=ub_overlap_rs, + ub_overlap_ag=ub_overlap_ag, ub_name="proj", **common_gemm_kwargs, ) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index 4ddab0e5a1..df571a0e6b 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -101,14 +101,14 @@ def fp8_gemm( empty_tensor if extra_output_tensor is None else extra_output_tensor ) args = tuple(args + (0, extra_output_tensor,)) - elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG: - fn = ub.split_overlap_ag + elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P: + fn = ub.split_overlap_ag_p2p extra_output_tensor = ( empty_tensor if extra_output_tensor is None else extra_output_tensor ) args = tuple(args + (extra_output_tensor,)) - elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG: - fn = ub.atomic_gemm_overlap_ag + elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P: + fn = ub.atomic_gemm_overlap_ag_p2p extra_output_tensor = ( empty_tensor if extra_output_tensor is None else extra_output_tensor ) @@ -119,12 +119,24 @@ def fp8_gemm( extra_output_tensor is not None ), 'SPLIT_PIPELINED_RS requires extra output tensor' args = tuple(args + (True, extra_output_tensor,)) + elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P: + fn = ub.split_overlap_rs_p2p + assert ( + extra_output_tensor is not None + ), 'SPLIT_PIPELINED_RS_P2P requires extra output tensor' + args = tuple(args + (extra_output_tensor,)) elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS: fn = ub.atomic_gemm_overlap_rs assert ( extra_output_tensor is not None ), 'ATOMIC_GEMM_RS requires extra output tensor' args = tuple(args + (True, extra_output_tensor,)) + elif ub_algo == tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P: + fn = ub.atomic_gemm_overlap_rs_p2p + assert ( + extra_output_tensor is not None + ), 'ATOMIC_GEMM_RS_P2P requires extra output tensor' + args = tuple(args + (extra_output_tensor,)) _ = fn(*args) return out, gelu_input @@ -217,8 +229,8 @@ def gemm( elif ub_algo == tex.UbufOverlapAlgo.BULK_OVERLAP_RS: fn = ub.bulk_overlap args = tuple(args + (0, empty_tensor)) - elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG: - fn = ub.split_overlap_ag + elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P: + fn = ub.split_overlap_ag_p2p extra_output_tensor = ( empty_tensor if extra_output_tensor is None else extra_output_tensor ) @@ -229,6 +241,12 @@ def gemm( extra_output_tensor is not None ), 'SPLIT_PIPELINED_RS requires extra output tensor' args = tuple(args + (False, extra_output_tensor,)) + elif ub_algo == tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P: + fn = ub.split_overlap_rs_p2p + assert ( + 
extra_output_tensor is not None + ), 'SPLIT_PIPELINED_RS_P2P requires extra output tensor' + args = tuple(args + (extra_output_tensor,)) _ = fn(*args) return out, grad_bias, gelu_input diff --git a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h index 5f8ccab334..817a3ef366 100644 --- a/transformer_engine/pytorch/csrc/comm_gemm_overlap.h +++ b/transformer_engine/pytorch/csrc/comm_gemm_overlap.h @@ -41,10 +41,12 @@ enum class COMM_TYPE { RS = 0, AG = 1 }; enum class UBOverlapAlgo { BULK_OVERLAP_AG = 0, BULK_OVERLAP_RS = 1, - SPLIT_PIPELINED_AG = 2, + SPLIT_PIPELINED_AG_P2P = 2, SPLIT_PIPELINED_RS = 3, - ATOMIC_GEMM_RS = 4, - ATOMIC_GEMM_AG = 5 + SPLIT_PIPELINED_RS_P2P = 4, + ATOMIC_GEMM_RS = 5, + ATOMIC_GEMM_AG_P2P = 6, + ATOMIC_GEMM_RS_P2P = 7 }; struct UbufBase { @@ -70,9 +72,10 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { int comm_sms; int cga_size; int use_ce; + bool _atomic_gemm; UbufCommOverlap(torch::Tensor sample, int rank, int tp_size, int num_comm_sm, int comm_cga_size, - int num_splits, bool set_sm_margin, int num_max_streams, + int num_splits, bool set_sm_margin, int num_max_streams, bool atomic_gemm, torch::Tensor empty_tensor) { // Initialize userbuf communicator if (!comm_created) { @@ -116,9 +119,12 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { _math_sms -= transformer_engine::getenv("NVTE_EXT_MARGIN_SM", 0); output_tensor = torch::Tensor(); - auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); - counter = torch::zeros({num_splits * 2}, counter_options); - counter.index_put_({Slice(None, num_splits)}, 1); + _atomic_gemm = atomic_gemm; + if (_atomic_gemm) { + auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); + counter = torch::zeros({num_splits * 2}, counter_options); + counter.index_put_({Slice(None, num_splits)}, 1); + } // CUDA event creation cudaEventCreateWithFlags(&_start_compute, 0); cudaEventCreateWithFlags(&_stop_compute, 0); @@ -519,12 +525,15 @@ struct UbufCommOverlap : torch::CustomClassHolder, UbufBase { output_tensor = torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options()); return output_tensor; } + + bool is_atomic_gemm() { return _atomic_gemm; } + bool is_p2p_overlap() { return false; } }; // UbufCommOverlap struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { int _tp_id; int _tp_size; - int _ub_reg; + int _ub_reg, _ub_reg2; int _next_rank, _prev_rank, _rank, _rank_round_tp; int _aggregate2; int _math_sms; @@ -533,18 +542,21 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { torch::Tensor _ubuf; torch::Tensor counter; torch::Tensor _empty_tensor; + torch::Tensor _ubuf_scale_inv; + bool _ubuf_scale_inv_initialized; std::vector _ubufs; at::cuda::CUDAStream _stream_send = at::cuda::getStreamFromPool(true); at::cuda::CUDAStream _stream_recv = at::cuda::getStreamFromPool(true); std::vector _stream_compute; - cudaEvent_t _start_compute, _stop_compute, _stop_send, _stop_recv; + cudaEvent_t _start_compute, _stop_compute, _start_comm, _stop_send, _stop_recv; int use_ce; int sms; int cga_size; + bool _atomic_gemm; UbufP2PCommOverlap(torch::Tensor sample, int rank, int tp_size, int num_comm_sm, int comm_cga_size, bool set_sm_margin, bool aggregate2, int num_max_streams, - torch::Tensor empty_tensor) { + bool is_reduce_scatter, bool atomic_gemm, torch::Tensor empty_tensor) { // Initialize userbuf communicator if (!comm_created) { if (rank == 0) { @@ 
-561,16 +573,25 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { // Create workspace tensor with userbuffer int ubuf_bytes = sample.numel() * sample.element_size(); int ubuf_chunk_bytes = ubuf_bytes / tp_size; + int num_ubuf_chunks = tp_size; + if (is_reduce_scatter) { + // GEMM + RS overlap: Allocate `2 x tp_size - 1` buffers to hold recieved GEMM chunk + // outputs for reduction at the end of the pipelining. + ubuf_bytes = static_cast(ubuf_bytes / tp_size * (tp_size * 2 - 1)); + num_ubuf_chunks = static_cast(tp_size * 2 - 1); + } _ub_reg = register_user_buffer_collective(reinterpret_cast(&_ubuf_ptr), ubuf_bytes, _ub_comm, true); if (rank == 0) { printf("!!! [UBP2P] Register UBuf %d\n", _ub_reg); } - _ubuf = torch::from_blob(_ubuf_ptr, {sample.size(0), sample.size(1)}, sample.options()); + + _ubuf = torch::from_blob( + _ubuf_ptr, {sample.size(0) / tp_size * num_ubuf_chunks, sample.size(1)}, sample.options()); // Create tensor chunks for easy management char *ubuf_byte_ptr = reinterpret_cast(_ubuf.data_ptr()); - for (int i = 0; i < tp_size; i++) { + for (int i = 0; i < num_ubuf_chunks; i++) { torch::Tensor ubuf_chunk = torch::from_blob( ubuf_byte_ptr, {sample.size(0) / tp_size, sample.size(1)}, sample.options()); _ubufs.push_back(ubuf_chunk); @@ -599,30 +620,37 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { _rank_round_tp = (rank / tp_size) * tp_size; _next_rank = (tp_size + rank + 1) % tp_size + _rank_round_tp; _prev_rank = (tp_size + rank + -1) % tp_size + _rank_round_tp; + _ubuf_scale_inv_initialized = false; - auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); - counter = torch::zeros({tp_size * 2}, counter_options); - counter.index_put_({Slice(None, tp_size)}, 1); - _self_chunk_id = _tp_id; - - const char *env_p = std::getenv("NVTE_AG_P2P_ATOMIC"); - if (rank == 0 && env_p != nullptr) { - if (env_p[0] == '1') { - printf("!!userbuffers_sendrecv_atomic\n"); - } else if (env_p[0] == '2') { - printf("!!userbuffers_sendrecv_multiatomic\n"); - } else if (env_p[0] == '3') { - printf("!!userbuffers_sendrecv_multiatomic_shuffle\n"); - _self_chunk_id = 0; - } else { - printf("!!userbuffers_sendrecv\n"); + _atomic_gemm = atomic_gemm; + if (_atomic_gemm) { + auto counter_options = torch::TensorOptions().dtype(torch::kInt32).device(torch::kCUDA); + counter = torch::zeros({tp_size * 2}, counter_options); + counter.index_put_({Slice(None, tp_size)}, 1); + _self_chunk_id = _tp_id; + + if (!is_reduce_scatter) { + const char *env_p = std::getenv("NVTE_AG_P2P_ATOMIC"); + if (rank == 0 && env_p != nullptr) { + if (env_p[0] == '1') { + printf("!!userbuffers_sendrecv_atomic\n"); + } else if (env_p[0] == '2') { + printf("!!userbuffers_sendrecv_multiatomic\n"); + } else if (env_p[0] == '3') { + printf("!!userbuffers_sendrecv_multiatomic_shuffle\n"); + _self_chunk_id = 0; + } else { + printf("!!userbuffers_sendrecv\n"); + } + } + counter.index_put_({_self_chunk_id}, 0); } } - counter.index_put_({_self_chunk_id}, 0); // CUDA event creation cudaEventCreateWithFlags(&_start_compute, 0); cudaEventCreateWithFlags(&_stop_compute, 0); + cudaEventCreateWithFlags(&_start_comm, 0); cudaEventCreateWithFlags(&_stop_send, 0); cudaEventCreateWithFlags(&_stop_recv, 0); } @@ -758,7 +786,8 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)stream_main, _stop_compute, 0)); return D; - } // split_overlap_ag + } // atomic_gemm_overlap_ag + /* ** Split AllGather + GEMM using P2P communication 
** This function assumes the input_b is pre-copied to _ubufs[rank_id]. This is @@ -948,6 +977,174 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { return D; } // split_overlap_ag +/* + ** Split ReduceScatter + GEMM using P2P communication + */ + void atomic_gemm_overlap_rs(at::Tensor A, at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, + at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, + at::Tensor rs_output) { + _ub_comm->use_ce = use_ce; + _ub_comm->sms = sms; + _ub_comm->cga_size = cga_size; + int k = A.size(1); + int n = B.size(0); + + // Get communication and GEMM input chunk sizes + int n_chunk = n / _tp_size; + const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + const int input_b_chunk_bytes = n_chunk * k * B.element_size(); + + // Get input and workspace data pointers + char *input_b_ptr = reinterpret_cast(B.data_ptr()); + char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); + int *counter_ptr = reinterpret_cast(counter.data_ptr()); + int workspace_size_chunk = workspaceSize / _stream_compute.size(); + + if (A_scale_inverse.numel()) + A_scale_inverse = A_scale_inverse[A_fp8_tensor]; + + if (B_scale_inverse.numel()) + B_scale_inverse = B_scale_inverse[B_fp8_tensor]; + + // Catch up the main stream + at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t)_stream_recv, _start_compute, 0)); + + // Atomic GEMM + torch::Tensor workspace_chunk = + torch::from_blob(workspace_ptr, {workspace_size_chunk}, workspace.options()); + te_atomic_gemm(A, A_scale_inverse, A_type, transa, B, B_scale_inverse, B_type, transb, + _ubuf, D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, + workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, + _math_sms, 0, _tp_size, true, counter); + + // P2P communication chunk + for (int i = 1; i < _tp_size; i++) { + int send_chunk_id = i - 1; + int recv_chunk_id = send_chunk_id + _tp_size; + int send_offset = comm_bytes * send_chunk_id; + int recv_offset = comm_bytes * recv_chunk_id; + int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; + int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; + + consumer(counter_ptr, send_chunk_id, (cudaStream_t)_stream_recv); + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, + _ub_comm, send_rank, (cudaStream_t) _stream_recv); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, + _ub_comm, recv_rank, (cudaStream_t) _stream_recv); + } + CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t) _stream_recv)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) stream_main, _stop_recv, 0)); + + // Reduce GEMM output chunks + char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); + torch::Tensor reduce_buf = torch::from_blob( + reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); + torch::sum_out(rs_output, reduce_buf, 0); + } + + /* + ** Split ReduceScatter + GEMM using P2P communication + */ + void split_overlap_rs(at::Tensor A, 
at::Tensor A_scale_inverse, int64_t A_fp8_tensor, + transformer_engine::DType A_type, bool transa, at::Tensor B, + at::Tensor B_scale_inverse, int64_t B_fp8_tensor, + transformer_engine::DType B_type, bool transb, at::Tensor D, + at::Tensor D_scale, transformer_engine::DType D_type, at::Tensor D_amax, + at::Tensor bias, transformer_engine::DType bias_type, + at::Tensor pre_gelu_out, bool grad, at::Tensor workspace, + size_t workspaceSize, bool accumulate, bool use_split_accumulator, + at::Tensor rs_output) { + _ub_comm->use_ce = use_ce; + _ub_comm->sms = sms; + _ub_comm->cga_size = cga_size; + int k = A.size(1); + int n = B.size(0); + + // Get communication and GEMM input chunk sizes + int n_chunk = n / _tp_size; + const int comm_bytes = _ubufs[0].numel() * _ubufs[0].element_size(); + const int input_b_chunk_bytes = n_chunk * k * B.element_size(); + + // Get input and workspace data pointers + char *input_b_ptr = reinterpret_cast(B.data_ptr()); + char *workspace_ptr = reinterpret_cast(workspace.data_ptr()); + int workspace_size_chunk = workspaceSize / _stream_compute.size(); + + if (A_scale_inverse.numel()) + A_scale_inverse = A_scale_inverse[A_fp8_tensor]; + + if (B_scale_inverse.numel()) + B_scale_inverse = B_scale_inverse[B_fp8_tensor]; + + // Catch up the main stream + at::cuda::CUDAStream stream_main = at::cuda::getDefaultCUDAStream(); + CHECK_CUDA(cudaEventRecord(_start_compute, (cudaStream_t)stream_main)); + for (int i = 0; i < _stream_compute.size(); i++) { + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) _stream_compute[i], _start_compute, 0)); + } + + // GEMM and send/recv chunks + for (int i = 0; i < _tp_size; i++) { + // GEMM chunk + int input_b_chunk_id = (_tp_id + i + 1) % _tp_size; + char* input_b_chunk_ptr = input_b_ptr + (input_b_chunk_id * input_b_chunk_bytes); + torch::Tensor input_b_chunk = torch::from_blob(input_b_chunk_ptr, {n_chunk, k}, B.options()); + // Store the last GEMM chunk output to the recieve buffer. 
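+      // Each pipelined GEMM writes its partial output into its own _ubufs[i] chunk and
+      // uses the workspace slice owned by compute stream (i % _stream_compute.size()),
+      // so concurrently running chunks never share cuBLAS scratch space. The last chunk
+      // is issued on the main stream, which keeps the final chunk reduction below
+      // ordered directly behind it.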
+ torch::Tensor workspace_chunk = torch::from_blob( + workspace_ptr + (i % _stream_compute.size()) * workspace_size_chunk, + {workspace_size_chunk}, workspace.options()); + if (i == _tp_size - 1) { + at::cuda::setCurrentCUDAStream(stream_main); + } else { + at::cuda::setCurrentCUDAStream(_stream_compute[i % _stream_compute.size()]); + } + te_gemm(A, A_scale_inverse, A_type, transa, input_b_chunk, B_scale_inverse, B_type, transb, + _ubufs[i], D_scale, D_type, D_amax, bias, bias_type, pre_gelu_out, grad, + workspace_chunk, workspace_size_chunk, accumulate, use_split_accumulator, + _math_sms); + + if (i > 0) { + // P2P communication chunk + int send_offset = comm_bytes * (i - 1); + int recv_offset = comm_bytes * (i - 1 + _tp_size); + int send_rank = (_tp_id + i) % _tp_size + _rank_round_tp; + int recv_rank = (_tp_size + _tp_id - i) % _tp_size + _rank_round_tp; + CHECK_CUDA(cudaEventRecord( + _start_comm, (cudaStream_t) _stream_compute[(i - 1) % _stream_compute.size()])); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) _stream_send, _start_comm, 0)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) _stream_recv, _start_comm, 0)); + userbuffers_send(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, + _ub_comm, send_rank, (cudaStream_t) _stream_send); + userbuffers_recv(_ub_reg, send_offset, _ub_reg, recv_offset, comm_bytes, + _ub_comm, recv_rank, (cudaStream_t) _stream_recv); + } + } + CHECK_CUDA(cudaEventRecord(_stop_recv, (cudaStream_t) _stream_recv)); + CHECK_CUDA(cudaStreamWaitEvent((cudaStream_t) stream_main, _stop_recv, 0)); + + // Reduce GEMM output chunks + char *reduce_buf_ptr = reinterpret_cast(_ubufs[_tp_size - 1].data_ptr()); + if (_ubuf.element_size() == 1 && rs_output.element_size() == 2) { + assert(_ubuf_scale_inv_initialized); + float *d_scale_inv_ptr = reinterpret_cast(_ubuf_scale_inv.data_ptr()); + char *rs_output_ptr = reinterpret_cast(rs_output.data_ptr()); + reduce_fp8_in_bf16_out<__nv_fp8_e4m3>(reduce_buf_ptr, rs_output_ptr, d_scale_inv_ptr, + _tp_size, _ubufs[0].numel(), (cudaStream_t) stream_main); + } else { + torch::Tensor reduce_buf = torch::from_blob( + reduce_buf_ptr, {_tp_size, _ubufs[0].size(0), _ubufs[0].size(1)}, _ubuf.options()); + torch::sum_out(rs_output, reduce_buf, 0); + } + } + /* ** Copy input to _ubufs[0] */ @@ -970,6 +1167,7 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { (cudaStream_t)stream_main)); } } + torch::Tensor get_ubuf_output(int comm_type) { char *ubuf_wt_ptr = reinterpret_cast(_ubuf.data_ptr()); COMM_TYPE _comm_type = static_cast(comm_type); @@ -981,6 +1179,15 @@ struct UbufP2PCommOverlap : torch::CustomClassHolder, UbufBase { int output_c_dim1 = _ubuf.size(1); return torch::from_blob(ubuf_wt_ptr, {output_c_dim0, output_c_dim1}, _ubuf.options()); } + + void set_ubuf_scale_inv(const torch::Tensor &scale_inv) { + _ubuf_scale_inv = scale_inv; + _ubuf_scale_inv_initialized = true; + } + + bool is_fp8_ubuf() { return (_ubuf.element_size() == 1); } + bool is_atomic_gemm() { return _atomic_gemm; } + bool is_p2p_overlap() { return true; } }; // UbufP2PCommOverlap } // namespace ubuf diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index b5aa10b150..328bf1dcb4 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -109,26 +109,36 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { .value("BULK_OVERLAP_AG", ubuf::UBOverlapAlgo::BULK_OVERLAP_AG) .value("BULK_OVERLAP_RS", 
ubuf::UBOverlapAlgo::BULK_OVERLAP_RS) .value("SPLIT_PIPELINED_RS", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS) - .value("SPLIT_PIPELINED_AG", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG) + .value("SPLIT_PIPELINED_RS_P2P", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_RS_P2P) + .value("SPLIT_PIPELINED_AG_P2P", ubuf::UBOverlapAlgo::SPLIT_PIPELINED_AG_P2P) .value("ATOMIC_GEMM_RS", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS) - .value("ATOMIC_GEMM_AG", ubuf::UBOverlapAlgo::ATOMIC_GEMM_AG); + .value("ATOMIC_GEMM_AG_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_AG_P2P) + .value("ATOMIC_GEMM_RS_P2P", ubuf::UBOverlapAlgo::ATOMIC_GEMM_RS_P2P); py::class_(m, "UbufCommOverlap") - .def(py::init()) + .def(py::init()) .def("bulk_overlap", &ubuf::UbufCommOverlap::bulk_overlap) .def("split_overlap_rs", &ubuf::UbufCommOverlap::split_overlap_rs) .def("set_ubuf_scale_inv", &ubuf::UbufCommOverlap::set_ubuf_scale_inv) .def("atomic_gemm_overlap_rs", &ubuf::UbufCommOverlap::atomic_gemm_overlap_rs) .def("is_fp8_ubuf", &ubuf::UbufCommOverlap::is_fp8_ubuf) .def("copy_input_to_ubuf", &ubuf::UbufCommOverlap::copy_input_to_ubuf) - .def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output); + .def("get_ubuf_output", &ubuf::UbufCommOverlap::get_ubuf_output) + .def("is_atomic_gemm", &ubuf::UbufCommOverlap::is_atomic_gemm) + .def("is_p2p_overlap", &ubuf::UbufCommOverlap::is_p2p_overlap); py::class_(m, "UbufP2PCommOverlap") - .def(py::init()) - .def("split_overlap_ag", &ubuf::UbufP2PCommOverlap::split_overlap_ag) - .def("atomic_gemm_overlap_ag", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag) + .def(py::init()) + .def("split_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_ag) + .def("split_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::split_overlap_rs) + .def("atomic_gemm_overlap_ag_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_ag) + .def("atomic_gemm_overlap_rs_p2p", &ubuf::UbufP2PCommOverlap::atomic_gemm_overlap_rs) .def("copy_input_to_ubuf", &ubuf::UbufP2PCommOverlap::copy_input_to_ubuf) - .def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output); + .def("get_ubuf_output", &ubuf::UbufP2PCommOverlap::get_ubuf_output) + .def("is_fp8_ubuf", &ubuf::UbufP2PCommOverlap::is_fp8_ubuf) + .def("is_atomic_gemm", &ubuf::UbufP2PCommOverlap::is_atomic_gemm) + .def("is_p2p_overlap", &ubuf::UbufP2PCommOverlap::is_p2p_overlap) + .def("set_ubuf_scale_inv", &ubuf::UbufP2PCommOverlap::set_ubuf_scale_inv); #else // NVTE_WITH_USERBUFFERS m.def("UbufOverlapAlgo", &placeholder, "Dummy function for python side annotations"); m.def("UbufCommOverlap", &placeholder, "Dummy function for python side annotations"); diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu index 76e9453efc..0ec89b0bb7 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.cu @@ -3666,3 +3666,34 @@ void consumer(void *atomic_ptr, int chunk_i, cudaStream_t stream) { dim3 grid(1); consumer_kernel<<>>(atomic_ptr, chunk_i); } + +template +__global__ void __launch_bounds__(MAX_THREADS / 4) +reduce_fp8_in_bf16_out_cuda(void *inputs, void *output, const float *scale, + const int num_inputs, const int input_size) { + const size_t tid = threadIdx.x + blockDim.x * blockIdx.x; + fp8type *inputs_fp8 = reinterpret_cast(inputs); + float accum_buf = static_cast(inputs_fp8[tid]) * (*scale); + #pragma unroll + for (int i = 1; i < num_inputs; i++) { + accum_buf += static_cast(inputs_fp8[tid + input_size * i]) * (*scale); + } + half 
*output_half = reinterpret_cast(output); + output_half[tid] = (half) accum_buf; +} + +template +void reduce_fp8_in_bf16_out(void *inputs, void *output, float *scale, int num_inputs, + int input_size, cudaStream_t stream) { + size_t num_threads = MAX_THREADS / 4; + size_t num_blocks = (input_size +num_threads - 1) / num_threads; + dim3 block(num_threads); + dim3 grid(num_blocks); + reduce_fp8_in_bf16_out_cuda<<>>( + inputs, output, scale, num_inputs, input_size); +} + +template void reduce_fp8_in_bf16_out<__nv_fp8_e4m3>( + void *inputs, void *output, float *scale, int num_inputs, int input_size, cudaStream_t stream); +template void reduce_fp8_in_bf16_out<__nv_fp8_e5m2>( + void *inputs, void *output, float *scale, int num_inputs, int input_size, cudaStream_t stream); diff --git a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h index 2d030a1409..407f9479c3 100644 --- a/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h +++ b/transformer_engine/pytorch/csrc/userbuffers/userbuffers.h @@ -305,4 +305,8 @@ void userbuffers_alltoall_recv(communicator *comm, cudaStream_t stream = 0); void destroy_communicator(communicator *comm); +template +void reduce_fp8_in_bf16_out(void *input, void *output, float *scale, int num_inputs, + int input_size, cudaStream_t stream); + #endif // TRANSFORMER_ENGINE_USERBUFFERS_H_ diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index ab24789549..59e5949e06 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -129,13 +129,14 @@ def initialize_ub( "qkv_fprop", "qkv_dgrad", "proj_dgrad", "fc1_fprop", "fc1_dgrad", "fc2_dgrad" ] if bool(int(os.getenv("NVTE_UB_FP8_RS", "0"))): - fp8_buf.append ("proj_fprop") + fp8_buf += ["proj_fprop", "fc2_fprop"] # Default overlap methods for layers methods = { "ring_exchange":["qkv_fprop", "fc1_fprop", "proj_dgrad", "fc2_dgrad"], "pipeline":["proj_fprop", "fc2_fprop"], "bulk":["qkv_dgrad", "qkv_wgrad", "fc1_dgrad", "fc1_wgrad"], } + layers_reduce_scatter_overlap = ["proj_fprop", "fc2_fprop", "qkv_wgrad", "fc1_wgrad"] def get_method(name): for method, names in methods.items(): @@ -151,7 +152,28 @@ def add_ub( set_sm_margin: int = 0, num_splits: int = 4, aggregate: int = 0, + atomic_gemm: int = 0, + is_reduce_scatter: int = 0, ) -> None: + if atomic_gemm: + warnings.warn( + "Atomic GEMM uses a beta API from cublas and is not tested for all use cases." + ) + assert use_fp8, "Atomic GEMM overlap supported only for FP8 GEMM." + if is_reduce_scatter and method == "ring_exchange": + raise ValueError( + "Atomic GEMM is not supported for ReduceScatter with `ring_exchange` method." + ) + if method == 'bulk': + warnings.warn( + "Atoimic GEMM not is supported for a bulk overlap." + "Defaulting to `atomic_gemm=False`." + ) + atomic_gemm = 0 + if not is_reduce_scatter and method == 'pipeline': + raise ValueError( + "`pipeline` overlap method is not supported for AllGather." 
+ ) sample_buffer = torch.empty( shape, dtype=torch.uint8 if (use_fp8 and name in fp8_buf) else dtype, @@ -166,6 +188,8 @@ def add_ub( set_sm_margin, # Set SM margin aggregate, # Aggregate 2X GEMM chunks _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams + is_reduce_scatter, # overlap with reduce scatter + atomic_gemm, # use a single GEMM with atomic-counters torch.Tensor(), # empty tensor to pass to counters ) else: @@ -178,6 +202,7 @@ def add_ub( num_splits, # Number of communication splits set_sm_margin, # Set SM margin _NUM_MAX_UB_STREAMS, # Max concurrent GEMM streams + atomic_gemm, # use a single GEMM with atomic-counters torch.Tensor(), # empty tensor to pass to counters ) _ub_communicators[name] = ub_obj @@ -191,6 +216,8 @@ def add_ub( num_splits = ub_cfg["num_splits"] if "num_splits" in ub_cfg else 0 set_sm_margin = ub_cfg["set_sm_margin"] if "set_sm_margin" in ub_cfg else 0 aggregate = ub_cfg["aggregate"] if "aggregate" in ub_cfg else 0 + atomic_gemm = ub_cfg["atomic_gemm"] if "atomic_gemm" in ub_cfg else 0 + is_reduce_scatter = 1 if name in layers_reduce_scatter_overlap else 0 add_ub( name, method, @@ -198,7 +225,9 @@ def add_ub( cga_size, set_sm_margin, num_splits, - aggregate + aggregate, + atomic_gemm, + is_reduce_scatter, ) else: method = get_method(name) @@ -632,12 +661,10 @@ def grad_output_preprocess( grad_output_mat = grad_output.view((-1, grad_output.shape[-1])) gather_grad_output = row_parallel_mode and ctx.sequence_parallel - if gather_grad_output: - ub_overlap_ag = ctx.ub_split_ag or ctx.ub_atomic_gemm_ag # No-FP8 case: bgrad is fused with wgrad for this case. if not ctx.fp8: if gather_grad_output: - if not ub_overlap_ag: + if not ctx.ub_overlap_ag: grad_output_mat, _ = gather_along_first_dim( grad_output_mat, ctx.tp_group ) @@ -656,7 +683,7 @@ def grad_output_preprocess( and ctx.fp8_meta["recipe"].override_linear_precision.wgrad ): assert ( - not ub_overlap_ag + not ctx.ub_overlap_ag ), "override_linear_precision.wgrad not supported with UB AG overlap" grad_output_mat, _ = gather_along_first_dim(grad_output_mat, ctx.tp_group) # FP8 case with gather: unfused bgrad, cast, transpose for efficient gather @@ -665,7 +692,7 @@ def grad_output_preprocess( grad_bias = grad_output_mat.sum(dim=0) else: grad_bias = None - if ub_overlap_ag: + if ctx.ub_overlap_ag: grad_output_c = ctx.ub_obj_gradout.get_ubuf_output(0) else: grad_output_c = torch.empty_like(grad_output_mat, dtype=torch.uint8) @@ -676,7 +703,7 @@ def grad_output_preprocess( fp8_dtype_backward, out=grad_output_c, ) - if not ub_overlap_ag: + if not ctx.ub_overlap_ag: grad_output_c, _ = gather_along_first_dim(grad_output_c, ctx.tp_group) grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) else: diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index eecd908e51..3711d9898f 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -86,8 +86,7 @@ def forward( primary_weights_in_fp8: bool, ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, - ub_split_ag: bool, - ub_atomic_gemm_ag: bool, + ub_overlap_ag: bool, ub_name: str, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible @@ -106,12 +105,11 @@ def forward( if ln_bias is not None: ln_bias = cast_if_needed(ln_bias, activation_dtype) - if ub_split_ag or ub_atomic_gemm_ag: + if ub_overlap_ag: tp_world_size = get_distributed_world_size(tp_group) if tp_world_size == 1 or (not 
is_grad_enabled) or return_layernorm_output: - ub_split_ag = False - ub_atomic_gemm_ag = False - if ub_split_ag or ub_atomic_gemm_ag: + ub_overlap_ag = False + if ub_overlap_ag: dim_size = list(inputmat.size()) dim_size[0] = dim_size[0] * tp_world_size ub_obj_lnout = get_ub(ub_name+"_fprop") @@ -119,8 +117,6 @@ def forward( else: ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype) - if ub_atomic_gemm_ag: - assert fp8, "AtomicGemm overlap supported only for FP8 GEMM." fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) @@ -138,9 +134,13 @@ def forward( # Column Parallel Linear ln_out_gathered = False - if ub_split_ag or ub_atomic_gemm_ag: + if ub_overlap_ag: ln_out_total = ub_obj_lnout.get_ubuf_output(1) ln_out = torch.empty_like(ln_out) + if ub_obj_lnout.is_atomic_gemm(): + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P elif parallel_mode == "column" and sequence_parallel: ln_out_gathered = True ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) @@ -201,8 +201,6 @@ def forward( ) weight_t_fp8 = None - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo out, _ = tex.fp8_gemm( weight_fp8._data, fp8_meta["scaling_fwd"].scale_inv, @@ -217,9 +215,9 @@ def forward( bias=bias, use_bias=use_bias, use_split_accumulator=_2X_ACC_FPROP, - ub_algo=ub_algo, - ub=ub_obj_lnout if (ub_split_ag or ub_atomic_gemm_ag) else None, - extra_output_tensor=ln_out if (ub_split_ag or ub_atomic_gemm_ag) else None, + ub_algo=ub_algo if ub_overlap_ag else None, + ub=ub_obj_lnout if ub_overlap_ag else None, + extra_output_tensor=ln_out if ub_overlap_ag else None, ) else: # Cast for native AMP @@ -243,9 +241,9 @@ def forward( get_workspace(), bias=bias, use_bias=use_bias, - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None, - ub=ub_obj_lnout if ub_split_ag else None, - extra_output_tensor=ln_out if ub_split_ag else None, + ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, + ub=ub_obj_lnout if ub_overlap_ag else None, + extra_output_tensor=ln_out if ub_overlap_ag else None, ) if is_grad_enabled: @@ -624,7 +622,6 @@ def backward( None, None, None, - None, ) @@ -737,8 +734,7 @@ def __init__( device: Union[torch.device, str] = "cuda", ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, - ub_split_ag: bool = False, - ub_atomic_gemm_ag: bool = False, + ub_overlap_ag: bool = False, ub_name: Optional[str] = None, ) -> None: super().__init__() @@ -758,23 +754,16 @@ def __init__( self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() self.ub_bulk_wgrad = ub_bulk_wgrad self.ub_bulk_dgrad = ub_bulk_dgrad - self.ub_split_ag = ub_split_ag - self.ub_atomic_gemm_ag = ub_atomic_gemm_ag - if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_split_ag]): + self.ub_overlap_ag = ub_overlap_ag + if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_ag]): assert ub_name is not None, "Userbuffer name [string] is not set." self.ub_name = ub_name - - if ub_bulk_wgrad or ub_bulk_dgrad or ub_split_ag or ub_atomic_gemm_ag: + if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_ag]): assert ( tex.userbuf_comm_available() ), "Userbuffer communication backend not available." - if ub_atomic_gemm_ag: - warnings.warn( - "Atomic gemm uses a beta API from cublas and is not tested for all use cases." 
- ) - if tp_group is None: self.tp_size = tp_size if tp_size == 1: @@ -1098,8 +1087,7 @@ def forward( self.primary_weights_in_fp8, self.ub_bulk_wgrad, self.ub_bulk_dgrad, - self.ub_split_ag, - self.ub_atomic_gemm_ag, + self.ub_overlap_ag, self.ub_name, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index bf9e6fe558..7d86658260 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -117,10 +117,8 @@ def forward( primary_weights_in_fp8: bool, ub_bulk_wgrad: bool, ub_bulk_dgrad: bool, - ub_split_rs: bool, - ub_atomic_gemm_rs: bool, - ub_split_ag: bool, - ub_atomic_gemm_ag: bool, + ub_overlap_rs: bool, + ub_overlap_ag: bool, gemm_gelu_fusion: bool, ) -> Union[Tuple[torch.Tensor, ...], torch.Tensor]: # Make sure input dimensions are compatible @@ -142,25 +140,17 @@ def forward( if ln_bias is not None: ln_bias = cast_if_needed(ln_bias, activation_dtype) - if ub_split_ag or ub_atomic_gemm_ag: - tp_world_size = get_distributed_world_size(tp_group) + tp_world_size = get_distributed_world_size(tp_group) + if ub_overlap_ag: if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output: - ub_split_ag = False - ub_atomic_gemm_ag = False - ub_overlap_ag = ub_split_ag or ub_atomic_gemm_ag + ub_overlap_ag = False if ub_overlap_ag: ub_obj_lnout = get_ub("fc1_fprop") ln_out = ub_obj_lnout.get_ubuf_output(0) else: ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype) - if ub_split_rs or ub_atomic_gemm_rs: - tp_world_size = get_distributed_world_size(tp_group) - if tp_world_size == 1: - ub_split_rs = False - ub_atomic_gemm_rs = False - if ub_atomic_gemm_rs or ub_atomic_gemm_ag: - assert fp8, "AtomicGemm overlap supported only for FP8 GEMM." 
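+        # With a tensor-parallel world size of 1 there is no communication to hide,
+        # so fall back to a plain, non-overlapped GEMM.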
+ ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs fp8_dtype_forward = get_fp8_te_dtype(fp8_meta["recipe"], fprop_tensor=True) @@ -181,6 +171,10 @@ def forward( if ub_overlap_ag: ln_out_total = ub_obj_lnout.get_ubuf_output(1) ln_out = torch.empty_like(ln_out) + if ub_obj_lnout.is_atomic_gemm(): + ub_algo_ag = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + else: + ub_algo_ag = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P elif set_parallel_mode and sequence_parallel: ln_out_gathered = True ln_out_total, _ = gather_along_first_dim(ln_out, tp_group) @@ -267,9 +261,6 @@ def forward( ) fc2_weight_t_fp8 = None - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ub_atomic_gemm_ag else ub_algo - # Perform FP8 GEMM fp8_gemm_args = [ fc1_weight_fp8._data, @@ -287,7 +278,7 @@ def forward( bias=fc1_bias, use_bias=use_fc1_bias, use_split_accumulator=_2X_ACC_FPROP, - ub_algo=ub_algo, + ub_algo=ub_algo_ag if ub_overlap_ag else None, ub=ub_obj_lnout if ub_overlap_ag else None, extra_output_tensor=ln_out if ub_overlap_ag else None, ) @@ -321,13 +312,23 @@ def forward( fc2_out_index, fc2_meta_tensor, fc2_te_type, out_type = ( None, None, None, activation_dtype) - if ub_split_rs or ub_atomic_gemm_rs: + if ub_overlap_rs: ub_obj_fc2out = get_ub("fc2_fprop") fc2_out = ub_obj_fc2out.get_ubuf_output(1) dim_size = list(gelu_out.size()) dim_size[0] = dim_size[0] // tp_world_size dim_size[1] = fc2_weight.size(0) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) + if ub_obj_fc2out.is_p2p_overlap(): + if ub_obj_fc2out.is_atomic_gemm(): + ub_algo_rs = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + else: + ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + else: + if ub_obj_fc2out.is_atomic_gemm(): + ub_algo_rs = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + else: + ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_obj_fc2out.is_fp8_ubuf(): fc2_out_index = tex.FP8FwdTensors.GEMM2_OUTPUT @@ -340,8 +341,6 @@ def forward( dim_size[1] = fc2_weight.size(0) fc2_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) - ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo _ = tex.fp8_gemm( fc2_weight_fp8._data, fp8_meta["scaling_fwd"].scale_inv, @@ -357,9 +356,9 @@ def forward( use_bias=use_fc2_bias, use_split_accumulator=_2X_ACC_FPROP, out=fc2_out, - ub_algo=ub_algo, - ub=ub_obj_fc2out if ub_split_rs or ub_atomic_gemm_rs else None, - extra_output_tensor=rs_out if ub_split_rs or ub_atomic_gemm_rs else None, + ub_algo=ub_algo_rs if ub_overlap_rs else None, + ub=ub_obj_fc2out if ub_overlap_rs else None, + extra_output_tensor=rs_out if ub_overlap_rs else None, out_index=fc2_out_index, fp8_meta_tensor = fc2_meta_tensor, D_dtype = fc2_te_type, @@ -395,9 +394,9 @@ def forward( bias=fc1_bias, use_bias=(not bias_gelu_nvfusion) and use_fc1_bias, gelu=not bias_gelu_nvfusion and (activation == 'gelu'), - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ub_split_ag else None, - ub=ub_obj_lnout if ub_split_ag else None, - extra_output_tensor=ln_out if ub_split_ag else None, + ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P if ub_overlap_ag else None, + ub=ub_obj_lnout if ub_overlap_ag else None, + extra_output_tensor=ln_out if ub_overlap_ag else None, ) if not is_grad_enabled: clear_tensor_data(ln_out_total) @@ -427,13 +426,17 @@ def forward( fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM2_WEIGHT] = 
\ torch.max(-amin, amax).float() - if ub_split_rs: + if ub_overlap_rs: ub_obj_fc2out = get_ub("fc2_fprop") fc2_out = ub_obj_fc2out.get_ubuf_output(1) dim_size = list(gelu_out.size()) dim_size[0] = dim_size[0] // tp_world_size dim_size[1] = fc2_weight.size(0) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=gelu_out.device) + if ub_obj_fc2out.is_p2p_overlap(): + ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + else: + ub_algo_rs = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS else: dim_size = list(gelu_out.size()) dim_size[1] = fc2_weight.size(0) @@ -446,9 +449,9 @@ def forward( bias=fc2_bias, use_bias=use_fc2_bias, out=fc2_out, - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None, - ub=ub_obj_fc2out if ub_split_rs else None, - extra_output_tensor=rs_out if ub_split_rs else None, + ub_algo=ub_algo_rs if ub_overlap_rs else None, + ub=ub_obj_fc2out if ub_overlap_rs else None, + extra_output_tensor=rs_out if ub_overlap_rs else None, ) if not is_grad_enabled: clear_tensor_data(gelu_out) @@ -515,13 +518,12 @@ def forward( ctx.zero_centered_gamma = zero_centered_gamma ctx.ub_bulk_wgrad = ub_bulk_wgrad ctx.ub_bulk_dgrad = ub_bulk_dgrad - ctx.ub_split_ag = ub_split_ag - ctx.ub_atomic_gemm_ag = ub_atomic_gemm_ag + ctx.ub_overlap_ag = ub_overlap_ag ctx.requires_dgrad = inp.requires_grad ctx.normalization = normalization # Row Parallel Linear - if ub_split_rs or ub_atomic_gemm_rs: + if ub_overlap_rs: fc2_out = rs_out elif set_parallel_mode and sequence_parallel: fc2_out, _ = reduce_scatter_along_first_dim(fc2_out, tp_group) @@ -590,18 +592,19 @@ def backward( dim_size[0] = dim_size[0] * tp_world_size ub_obj_lnout = get_ub("fc1_dgrad") ub_obj_lnout.copy_input_to_ubuf(ln_out, 1) - ub_overlap_ag = ctx.ub_split_ag or ctx.ub_atomic_gemm_ag - if ub_overlap_ag: + if ctx.ub_overlap_ag: tp_world_size = get_distributed_world_size(ctx.tp_group) if tp_world_size == 1: - ctx.ub_split_ag = False ctx.ub_overlap_ag = False - ub_overlap_ag = ctx.ub_split_ag or ctx.ub_atomic_gemm_ag - if ub_overlap_ag: + if ctx.ub_overlap_ag: dim_size = list(grad_outputs[0].size()) dim_size[0] = dim_size[0] * tp_world_size ctx.ub_obj_gradout = get_ub("fc2_dgrad") + if ctx.ub_obj_gradout.is_atomic_gemm(): + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P ctx.use_bias = ctx.use_fc2_bias # For grad_output_preprocess ( @@ -645,8 +648,6 @@ def backward( ctx.fp8_meta["recipe"], fprop_tensor=False ) - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ctx.ub_atomic_gemm_ag else ub_algo # FC2 DGRAD; Unconditional fc2_dgrad, _ = tex.fp8_gemm( fc2_weight_t_fp8._data, @@ -660,10 +661,10 @@ def backward( ctx.activation_dtype, get_workspace(), use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=ub_algo, - ub=ctx.ub_obj_gradout if ub_overlap_ag else None, + ub_algo=ub_algo if ctx.ub_overlap_ag else None, + ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, ) - if ub_overlap_ag: + if ctx.ub_overlap_ag: grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) clear_tensor_data(grad_output_c) @@ -801,8 +802,9 @@ def backward( gelu=(not ctx.bias_gelu_nvfusion) and (ctx.activation == 'gelu'), grad=True, gelu_input=fc1_out, - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None, - ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None, + ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P \ + if ctx.ub_overlap_ag else None, + ub=ctx.ub_obj_gradout if 
ctx.ub_overlap_ag else None, ) # FC2 WGRAD @@ -1070,8 +1072,6 @@ def backward( None, None, None, - None, - None, ) @@ -1194,10 +1194,8 @@ def __init__( device: Union[torch.device, str] = "cuda", ub_bulk_wgrad: bool = False, ub_bulk_dgrad: bool = False, - ub_split_rs: bool = False, - ub_atomic_gemm_rs: bool = False, - ub_split_ag: bool = False, - ub_atomic_gemm_ag: bool = False, + ub_overlap_rs: bool = False, + ub_overlap_ag: bool = False, ) -> None: super().__init__() @@ -1218,29 +1216,18 @@ def __init__( self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() self.ub_bulk_wgrad = ub_bulk_wgrad self.ub_bulk_dgrad = ub_bulk_dgrad - self.ub_split_rs = ub_split_rs - self.ub_split_ag = ub_split_ag - self.ub_atomic_gemm_rs = ub_atomic_gemm_rs - self.ub_atomic_gemm_ag = ub_atomic_gemm_ag + self.ub_overlap_rs = ub_overlap_rs + self.ub_overlap_ag = ub_overlap_ag # GEMM-GELU fusion is currently only supported with split GEMM-AG overlap - self.gemm_gelu_fusion = (bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) and - self.activation == 'gelu' and self.ub_split_ag) - - if (ub_bulk_wgrad # pylint: disable=too-many-boolean-expressions - or ub_bulk_dgrad - or ub_split_rs - or ub_split_ag - or ub_atomic_gemm_rs - or ub_atomic_gemm_ag): + self.gemm_gelu_fusion = \ + (bool(int(os.getenv("NVTE_GEMM_GELU_FUSION", "0"))) and + self.activation == 'gelu' and not get_ub("fc1_fprop").is_atomic_gemm()) + + if any([ub_bulk_wgrad, ub_bulk_dgrad, ub_overlap_rs, ub_overlap_ag]): assert ( tex.userbuf_comm_available() ), "Userbuffer communication backend not available." - if ub_atomic_gemm_rs or ub_atomic_gemm_ag: - warnings.warn( - "Atomic gemm uses a beta API from cublas and is not tested for all use cases." - ) - if tp_group is None: self.tp_size = tp_size if tp_size == 1: @@ -1490,10 +1477,8 @@ def forward( self.primary_weights_in_fp8, self.ub_bulk_wgrad, self.ub_bulk_dgrad, - self.ub_split_rs, - self.ub_atomic_gemm_rs, - self.ub_split_ag, - self.ub_atomic_gemm_ag, + self.ub_overlap_rs, + self.ub_overlap_ag, self.gemm_gelu_fusion, ) out = fwd_fn(*args) diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py index d22242abb4..1f7898a592 100644 --- a/transformer_engine/pytorch/module/linear.py +++ b/transformer_engine/pytorch/module/linear.py @@ -3,7 +3,6 @@ # See LICENSE for license information. """Linear API""" -import warnings from typing import Union, Optional, Callable, Tuple, List, Dict, Any import torch @@ -79,10 +78,8 @@ def forward( parallel_mode: Union[str, None], is_grad_enabled: bool, primary_weights_in_fp8: bool, - ub_split_rs: bool, - ub_split_ag: bool, - ub_atomic_gemm_rs: bool, - ub_atomic_gemm_ag: bool, + ub_overlap_rs: bool, + ub_overlap_ag: bool, ub_name: str ) -> torch.Tensor: # Make sure input dimensions are compatible @@ -94,14 +91,8 @@ def forward( assert_dim_for_fp8_exec(weight) update_fp8_weights = is_first_microbatch is None or is_first_microbatch - - if ub_split_rs or ub_atomic_gemm_rs: - tp_world_size = get_distributed_world_size(tp_group) - if tp_world_size == 1: - ub_split_rs = False - ub_atomic_gemm_rs = False - if ub_atomic_gemm_rs or ub_atomic_gemm_ag: - assert fp8, "AtomicGemm overlap supported only for FP8 GEMM." 
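+        # A single `ub_overlap_rs` flag now requests reduce-scatter overlap; whether the
+        # pipelined or atomic-GEMM (and P2P or collective) variant runs is resolved below
+        # from the userbuffer object registered for this layer. Hypothetical usage sketch:
+        #   transformer_engine.pytorch.Linear(in_features, out_features, ub_name="proj",
+        #                                     ub_overlap_rs=True, ub_overlap_ag=True)
+        # after the communication buffers have been created with initialize_ub().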
+ tp_world_size = get_distributed_world_size(tp_group) + ub_overlap_rs = False if tp_world_size == 1 else ub_overlap_rs # Cast input to expected dtype inputmat = cast_if_needed(inputmat, activation_dtype) @@ -180,14 +171,23 @@ def forward( proj_out_index, meta_tensor, proj_out_tetype, proj_out_pttype = ( None, None, None, activation_dtype) - if ub_split_rs or ub_atomic_gemm_rs: + if ub_overlap_rs: ub_obj_projout = get_ub(ub_name+"_fprop") out = ub_obj_projout.get_ubuf_output(1) dim_size = list(inputmat_total.size()) dim_size[0] = dim_size[0] // tp_world_size dim_size[1] = weight.size(0) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) - + if ub_obj_projout.is_p2p_overlap(): + if ub_obj_projout.is_atomic_gemm(): + ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS_P2P + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + else: + if ub_obj_projout.is_atomic_gemm(): + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_RS + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_obj_projout.is_fp8_ubuf(): proj_out_index = tex.FP8FwdTensors.GEMM1_OUTPUT meta_tensor = fp8_meta["scaling_fwd"] @@ -199,8 +199,6 @@ def forward( dim_size[1] = weight.size(0) out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) - ub_algo=tex.UbufOverlapAlgo.ATOMIC_GEMM_RS if ub_atomic_gemm_rs else None - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else ub_algo _ = fp8_gemm( weight_fp8._data, fp8_meta["scaling_fwd"].scale_inv, @@ -216,9 +214,9 @@ def forward( use_bias=use_bias, use_split_accumulator=_2X_ACC_FPROP, out=out, - ub_algo=ub_algo, - ub=ub_obj_projout if (ub_split_rs or ub_atomic_gemm_rs) else None, - extra_output_tensor=rs_out if (ub_split_rs or ub_atomic_gemm_rs) else None, + ub_algo=ub_algo if ub_overlap_rs else None, + ub=ub_obj_projout if ub_overlap_rs else None, + extra_output_tensor=rs_out if ub_overlap_rs else None, out_index=proj_out_index, fp8_meta_tensor = meta_tensor, D_dtype = proj_out_tetype, @@ -238,13 +236,17 @@ def forward( fp8_meta["scaling_fwd"].amax_history[0][tex.FP8FwdTensors.GEMM1_WEIGHT] = \ torch.max(-amin, amax).float() - if ub_split_rs: - ub_obj_projout = get_ub("proj_fprop") + if ub_overlap_rs: + ub_obj_projout = get_ub(ub_name+"_fprop") out = ub_obj_projout.get_ubuf_output(1) dim_size = list(inputmat_total.size()) - dim_size[0] = dim_size[0] // tp_world_size + dim_size[0] = dim_size[0] // get_distributed_world_size(tp_group) dim_size[1] = weight.size(0) rs_out = torch.empty(dim_size, dtype=activation_dtype, device=inputmat_total.device) + if ub_obj_projout.is_p2p_overlap(): + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS_P2P + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS else: dim_size = list(inputmat_total.size()) dim_size[1] = weight.size(0) @@ -258,9 +260,9 @@ def forward( bias=bias, use_bias=use_bias, out=out, - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_RS if ub_split_rs else None, - ub=ub_obj_projout if ub_split_rs else None, - extra_output_tensor=rs_out if ub_split_rs else None, + ub_algo=ub_algo if ub_overlap_rs else None, + ub=ub_obj_projout if ub_overlap_rs else None, + extra_output_tensor=rs_out if ub_overlap_rs else None, ) if is_grad_enabled: @@ -307,14 +309,13 @@ def forward( ctx.inp_shape = inp.shape ctx.parallel_mode = parallel_mode ctx.tp_group = tp_group - ctx.ub_split_ag = ub_split_ag - ctx.ub_atomic_gemm_ag = ub_atomic_gemm_ag + ctx.ub_overlap_ag = ub_overlap_ag ctx.ub_name = ub_name ctx.tp_size = tp_size ctx.requires_dgrad = inp.requires_grad # Row Parallel 
Linear - if ub_split_rs or ub_atomic_gemm_rs: + if ub_overlap_rs: out = rs_out elif parallel_mode == "row" and sequence_parallel: out, _ = reduce_scatter_along_first_dim(out, tp_group) @@ -350,16 +351,16 @@ def backward( weight_t_fp8 = weight.transpose( update_cache="reuse_only" if ctx.is_first_microbatch is None else "lazy", ) - - if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag: - tp_world_size = get_distributed_world_size(ctx.tp_group) - if tp_world_size == 1: - ctx.ub_split_ag = False - ctx.ub_atomic_gemm_ag = False - if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag: + tp_world_size = get_distributed_world_size(ctx.tp_group) + ctx.ub_overlap_ag = False if tp_world_size == 1 else ctx.ub_overlap_ag + if ctx.ub_overlap_ag: dim_size = list(grad_output.size()) dim_size[0] = dim_size[0] * tp_world_size ctx.ub_obj_gradout = get_ub(ctx.ub_name+"_dgrad") + if ctx.ub_obj_gradout.is_atomic_gemm(): + ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P + else: + ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P ( grad_output, grad_output_c, @@ -397,8 +398,6 @@ def backward( ctx.fp8_meta["recipe"], fprop_tensor=False ) - ub_algo = tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None - ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG if ctx.ub_atomic_gemm_ag else ub_algo if ctx.requires_dgrad: if ctx.fp8: dgrad, _ = fp8_gemm( @@ -413,8 +412,8 @@ def backward( ctx.activation_dtype, get_workspace(), use_split_accumulator=_2X_ACC_DGRAD, - ub_algo=ub_algo, - ub=ctx.ub_obj_gradout if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag else None, + ub_algo=ub_algo if ctx.ub_overlap_ag else None, + ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, ) else: dgrad, _, _ = gemm( @@ -424,8 +423,9 @@ def backward( get_workspace(), layout="NN", grad=True, - ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG if ctx.ub_split_ag else None, - ub=ctx.ub_obj_gradout if ctx.ub_split_ag else None, + ub_algo=tex.UbufOverlapAlgo.SPLIT_PIPELINED_AG_P2P \ + if ctx.ub_overlap_ag else None, + ub=ctx.ub_obj_gradout if ctx.ub_overlap_ag else None, ) # Overlap dgrad-RS/AR with wgrad @@ -442,7 +442,7 @@ def backward( if ctx.fp8: # WGRAD if not ctx.fp8_meta["recipe"].override_linear_precision.wgrad: - if ctx.ub_split_ag or ctx.ub_atomic_gemm_ag: + if ctx.ub_overlap_ag: grad_output_t = tex.fp8_transpose(grad_output_c, fp8_dtype_backward) if inputmat_t_total is None: inputmat_t_total = tex.fp8_transpose(inputmat_total, fp8_dtype_backward) @@ -542,8 +542,6 @@ def backward( None, None, None, - None, - None, ) @@ -629,10 +627,8 @@ def __init__( parallel_mode: Optional[str] = None, parameters_split: Optional[Union[Tuple[str, ...], Dict[str, int]]] = None, device: Union[torch.device, str] = "cuda", - ub_split_rs: bool = False, - ub_split_ag: bool = False, - ub_atomic_gemm_rs: bool = False, - ub_atomic_gemm_ag: bool = False, + ub_overlap_rs: bool = False, + ub_overlap_ag: bool = False, ub_name: Optional[str] = None, ) -> None: super().__init__() @@ -645,28 +641,18 @@ def __init__( self.return_bias = return_bias self.apply_bias = bias and not return_bias self.primary_weights_in_fp8 = FP8GlobalStateManager.with_fp8_parameters() - self.ub_split_rs = ub_split_rs - self.ub_split_ag = ub_split_ag - self.ub_atomic_gemm_rs = ub_atomic_gemm_rs - self.ub_atomic_gemm_ag = ub_atomic_gemm_ag - if any([ub_atomic_gemm_rs, ub_atomic_gemm_ag]): + self.ub_overlap_rs = ub_overlap_rs + self.ub_overlap_ag = ub_overlap_ag + if ub_overlap_rs or ub_overlap_ag: assert ub_name is not None, "Userbuffer name [string] is not set." 
+ assert ( + tex.userbuf_comm_available() + ), "Userbuffer communication backend not available." self.ub_name = ub_name self.get_rng_state_tracker = get_rng_state_tracker if device == 'meta': assert parameters_split is None, ("Cannot split module parameters " "on 'meta' device.") - - if ub_split_rs or ub_split_ag or ub_atomic_gemm_rs: - assert ( - tex.userbuf_comm_available() - ), "Userbuffer communication backend not available." - - if ub_atomic_gemm_rs or ub_atomic_gemm_ag: - warnings.warn( - "Atomic gemm uses a beta API from cublas and is not tested for all use cases." - ) - if tp_group is None: self.tp_size = tp_size if tp_size == 1: @@ -930,10 +916,8 @@ def forward( self.parallel_mode, torch.is_grad_enabled(), self.primary_weights_in_fp8, - self.ub_split_rs, - self.ub_split_ag, - self.ub_atomic_gemm_rs, - self.ub_atomic_gemm_ag, + self.ub_overlap_rs, + self.ub_overlap_ag, self.ub_name, ) out = linear_fn(*args) diff --git a/transformer_engine/pytorch/transformer.py b/transformer_engine/pytorch/transformer.py index ab69f8e690..a0fd231913 100644 --- a/transformer_engine/pytorch/transformer.py +++ b/transformer_engine/pytorch/transformer.py @@ -259,10 +259,8 @@ def __init__( ub_tp_comm_overlap: bool = False, ub_bulk_wgrad: bool = True, ub_bulk_dgrad: bool = True, - ub_split_ag: bool = True, - ub_split_rs: bool = True, - ub_atomic_gemm_ag: bool = False, - ub_atomic_gemm_rs: bool = False, + ub_overlap_ag: bool = True, + ub_overlap_rs: bool = True, bias: bool = True, activation: str = 'gelu', normalization: str = "LayerNorm", @@ -282,21 +280,8 @@ def __init__( params_dtype = torch.get_default_dtype() if params_dtype is None else params_dtype ub_bulk_wgrad = ub_tp_comm_overlap and ub_bulk_wgrad ub_bulk_dgrad = ub_tp_comm_overlap and ub_bulk_dgrad - ub_split_ag = ub_tp_comm_overlap and ub_split_ag - ub_split_rs = ub_tp_comm_overlap and ub_split_rs - ub_atomic_gemm_rs = ub_tp_comm_overlap and ub_atomic_gemm_rs - assert ( - not (ub_split_rs and ub_atomic_gemm_rs) - ), "Only one type of RS overlap ub_split_rs/ub_atomic_gemm_rs should be enabled." - ub_atomic_gemm_ag = ub_tp_comm_overlap and ub_atomic_gemm_ag - assert ( - not (ub_split_ag and ub_atomic_gemm_ag) - ), "Only one type of AG overlap ub_split_ag/ub_atomic_gemm_ag should be enabled." - - if ub_atomic_gemm_rs or ub_atomic_gemm_ag: - warnings.warn( - "Atomic gemm uses a beta API from cublas and is not tested for all use cases." 
- ) + ub_overlap_ag = ub_tp_comm_overlap and ub_overlap_ag + ub_overlap_rs = ub_tp_comm_overlap and ub_overlap_rs bias_dropout_fusion = bool(int(os.getenv("NVTE_BIAS_DROPOUT_FUSION", "1"))) self.layer_number = layer_number @@ -370,10 +355,8 @@ def __init__( "qkv_weight_interleaved" : qkv_weight_interleaved, "ub_bulk_wgrad" : ub_bulk_wgrad, "ub_bulk_dgrad" : ub_bulk_dgrad, - "ub_split_ag" : ub_split_ag, - "ub_split_rs" : ub_split_rs, - "ub_atomic_gemm_rs" : ub_atomic_gemm_rs, - "ub_atomic_gemm_ag" : ub_atomic_gemm_ag, + "ub_overlap_ag" : ub_overlap_ag, + "ub_overlap_rs" : ub_overlap_rs, "qkv_format" : self.attn_input_format, } @@ -427,10 +410,8 @@ def __init__( zero_centered_gamma=zero_centered_gamma, ub_bulk_wgrad=ub_bulk_wgrad, ub_bulk_dgrad=ub_bulk_dgrad, - ub_split_rs=ub_split_rs, - ub_split_ag=ub_split_ag, - ub_atomic_gemm_rs=ub_atomic_gemm_rs, - ub_atomic_gemm_ag=ub_atomic_gemm_ag, + ub_overlap_rs=ub_overlap_rs, + ub_overlap_ag=ub_overlap_ag, activation=activation, normalization=normalization, device=device, From 8e672ff0758033c348e263dbcd6a4b3578c01161 Mon Sep 17 00:00:00 2001 From: Reese Wang Date: Fri, 22 Mar 2024 23:00:55 +0800 Subject: [PATCH 10/16] [JAX] Refactor fused attention (#711) * Remove unused headers Signed-off-by: Reese Wang * Unify the fused attn workspace size cpp code Signed-off-by: Reese Wang * Reduce the skipped cases Signed-off-by: Reese Wang * Rename self/cross attention to qkvpacked/kvpacked Signed-off-by: Reese Wang * Update attention mask docs Signed-off-by: Reese Wang * Refine the attn mask implementations Signed-off-by: Reese Wang --------- Signed-off-by: Reese Wang --- tests/jax/test_distributed_fused_attn.py | 40 +- tests/jax/test_fused_attn.py | 102 +- tests/jax/test_layer.py | 2 + tests/jax/test_praxis_layers.py | 33 +- transformer_engine/jax/cpp_extensions.py | 1535 ++++++-------------- transformer_engine/jax/csrc/extensions.cpp | 8 - transformer_engine/jax/csrc/modules.cpp | 795 ++++------ transformer_engine/jax/csrc/modules.h | 53 +- transformer_engine/jax/flax/transformer.py | 128 +- transformer_engine/jax/fused_attn.py | 271 ++-- 10 files changed, 1003 insertions(+), 1964 deletions(-) diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 580125290b..d5b23667d1 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -16,7 +16,7 @@ from utils import make_causal_mask, make_self_mask from transformer_engine.jax import fp8_autocast from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available -from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn +from transformer_engine.jax.fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout DTYPES = [jnp.float16, jnp.bfloat16] @@ -86,15 +86,15 @@ def test_self_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, dat def target_func(qkv, bias, mask): return jnp.mean( - self_fused_attn(qkv, - bias, - mask, - None, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_prob, - is_training=is_training)) + fused_attn_qkvpacked(qkv, + bias, + mask, + None, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_prob, + is_training=is_training)) def ref_func(qkv, bias, mask): query, key, value = jnp.split(qkv, [1, 2], axis=-3) @@ -192,16 +192,16 
@@ def test_cross_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, da def target_func(q, kv, mask): return jnp.mean( - cross_fused_attn(q, - kv, - None, - mask, - None, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_prob, - is_training=is_training)) + fused_attn_kvpacked(q, + kv, + None, + mask, + None, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_prob, + is_training=is_training)) def ref_func(query, kv, mask): key, value = jnp.split(kv, [1], axis=-3) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 35a1e4218f..483f070559 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -2,8 +2,6 @@ # # See LICENSE for license information. """Tests for fused attention""" -import sys - from enum import Enum from dataclasses import dataclass from functools import partial @@ -21,7 +19,7 @@ from jax.typing import ArrayLike, DTypeLike from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout -from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn, fused_attn +from transformer_engine.jax.fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked, fused_attn from transformer_engine.jax.cpp_extensions import FusedAttnHelper from transformer_engine_jax import NVTE_Fused_Attn_Backend @@ -144,18 +142,22 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng case QKVLayout.BS3HD: query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value]) qkv = jnp.concatenate((query, key, value), axis=-3) - return self_fused_attn(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype) + return fused_attn_qkvpacked(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype) case QKVLayout.BSHD_BS2HD: key, value = map(partial(jnp.expand_dims, axis=-3), [key, value]) kv = jnp.concatenate((key, value), axis=-3) - return cross_fused_attn(query, kv, bias, mask, dropout_rng, - **kwargs).astype(query.dtype) + return fused_attn_kvpacked(query, kv, bias, mask, dropout_rng, + **kwargs).astype(query.dtype) case QKVLayout.BSHD_BSHD_BSHD: return fused_attn(query, key, value, bias, mask, dropout_rng, **kwargs).astype(query.dtype) class BiasShape(Enum): + """ + Enum class to represent the different bias shapes used in the fused attention. 
+ """ + BIAS_1HSS = '1HSS' BIAS_B1SS = 'B1SS' BIAS_BHSS = 'BHSS' @@ -188,17 +190,16 @@ def _check_configs(self): if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv: pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.") - self.backend = FusedAttnHelper( - self.dtype, self.dtype, self.qkv_layout.value, self.attn_bias_type.value, - self.attn_mask_type.value, self.dropout_prob, self.num_heads_q, self.num_heads_kv, - self.max_seqlen_q, self.max_seqlen_kv, self.head_dim).get_fused_attn_backend() + self.backend = FusedAttnHelper(self.dtype, self.dtype, self.qkv_layout.value, + self.attn_bias_type.value, self.attn_mask_type.value, + self.dropout_prob, self.num_heads_q, self.num_heads_kv, + self.max_seqlen_q, self.max_seqlen_kv, + self.head_dim).get_fused_attn_backend() if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend: pytest.skip("Unsupported inputs combination or device compute capability.") - if self.bias_shape != BiasShape.BIAS_1HSS: - if self.attn_bias_type != AttnBiasType.POST_SCALE_BIAS: - pytest.skip("B1SS, BHSS and 11SS bias shapes require POST_SCALE_BIAS.") - elif self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: + if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS: + if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for " "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK.") elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: @@ -213,7 +214,9 @@ def _setup_inputs(self): q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim) k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim) - if self.bias_shape == BiasShape.BIAS_1HSS: + if self.attn_bias_type == AttnBiasType.NO_BIAS: + bias_shape = None + elif self.bias_shape == BiasShape.BIAS_1HSS: bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv) elif self.bias_shape == BiasShape.BIAS_B1SS: bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv) @@ -222,7 +225,7 @@ def _setup_inputs(self): elif self.bias_shape == BiasShape.BIAS_11SS: bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv) else: - pytest.xfail("PyTest attempted to use an unrecognized bias layout!") + pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!") self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.) self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.) @@ -237,7 +240,7 @@ def _setup_inputs(self): cudnn_neg_inf = -2.**27. if self.dtype == jnp.bfloat16 else -2.**15. 
self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype) max_id = min(self.max_seqlen_q, self.max_seqlen_kv) - seq_id_size = max_id * 5 // 128 # 5 ids per interval of 128 sequences + seq_id_size = max_id * 5 // 128 # 5 ids per interval of 128 sequences seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist() for i in range(1, len(seq_id)): self.bias = \ @@ -327,8 +330,8 @@ def grad_func(func, *args, **kwargs): **kwargs), arg_nums)) jitted_reference = jit( value_and_grad( - lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, - **kwargs), arg_nums)) + lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs), + arg_nums)) primitive_out, primitive_dgrad = jitted_primitive(*args) reference_out, reference_dgrad = jitted_reference(*args) @@ -361,10 +364,10 @@ def check_dqkv(primitive, reference, valid_len): primitive_dbias = jnp.float32(primitive_dgrad[3]) reference_dbias = jnp.float32(reference_dgrad[3]) - assert_allclose( - primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:], - jnp.zeros_like(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:]), - dtype=self.dtype) + assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:], + jnp.zeros_like(primitive_dbias[..., self.valid_len_q:, + self.valid_len_kv:]), + dtype=self.dtype) # dbias padded part assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:], @@ -376,15 +379,13 @@ def check_dqkv(primitive, reference, valid_len): reference_dbias[..., :self.valid_len_q, :self.valid_len_kv], dtype=self.dtype) -@pytest.mark.parametrize('bias_shape', [ - pytest.param(BiasShape.BIAS_1HSS, id='1-H-S-S'), - pytest.param(BiasShape.BIAS_B1SS, id='B-1-S-S'), - pytest.param(BiasShape.BIAS_BHSS, id='B-H-S-S'), - pytest.param(BiasShape.BIAS_11SS, id='1-1-S-S'), -]) -@pytest.mark.parametrize('attn_bias_type', [ - pytest.param(AttnBiasType.NO_BIAS, id='NO_BIAS'), - pytest.param(AttnBiasType.POST_SCALE_BIAS, id='POST_SCALE_BIAS'), + +@pytest.mark.parametrize('attn_bias_type, bias_shape', [ + pytest.param(AttnBiasType.NO_BIAS, None, id='NO_BIAS'), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id='POST_SCALE_BIAS-1HSS'), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id='POST_SCALE_BIAS-B1SS'), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id='POST_SCALE_BIAS-BHSS'), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id='POST_SCALE_BIAS-11SS'), ]) @pytest.mark.parametrize('attn_mask_type', [ pytest.param(AttnMaskType.NO_MASK, id='NO_MASK'), @@ -399,31 +400,32 @@ def check_dqkv(primitive, reference, valid_len): ]) @pytest.mark.parametrize('dtype', [ pytest.param(jnp.bfloat16, id="BF16"), - pytest.param(jnp.float16, id="FP16") + pytest.param(jnp.float16, id="FP16"), ]) -@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d',[ - pytest.param(32, 128, 128, 16, 16, 64, id='32-128-128-16-16-64-SELF'), - pytest.param( 4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'), - pytest.param(32, 512, 128, 16, 16, 64, id='32-512-128-16-16-64-CROSS'), - pytest.param( 4, 2048, 1024, 12, 12, 64, id='4-2048-1048-12-12-64-CROSS'), - pytest.param(32, 128, 128, 16, 8, 64, id='32-128-128-16-8-64-GQA'), - pytest.param( 4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA') +@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d', [ + pytest.param(32, 128, 128, 16, 16, 64, id='32-128-128-16-16-64-SELF'), + pytest.param(4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'), + pytest.param(32, 
512, 128, 16, 16, 64, id='32-512-128-16-16-64-CROSS'), + pytest.param(4, 2048, 1024, 12, 12, 64, id='4-2048-1048-12-12-64-CROSS'), + pytest.param(32, 128, 128, 16, 8, 64, id='32-128-128-16-8-64-GQA'), + pytest.param(4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA'), ]) @pytest.mark.parametrize('dropout_prob', [ pytest.param(0.0, id="DROP_0.0"), - pytest.param(0.1, id="DROP_0.1") -]) -@pytest.mark.parametrize('is_training', [ - pytest.param(True, id='TRAINING'), - pytest.param(False, id='INFERENCE'), + pytest.param(0.1, id="DROP_0.1"), ]) class TestFusedAttn: """ Fused attention tester """ + @staticmethod - def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, - dropout_prob, dtype, is_training, qkv_layout, bias_shape): + @pytest.mark.parametrize('is_training', [ + pytest.param(True, id='TRAINING'), + pytest.param(False, id='INFERENCE'), + ]) + def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob, + dtype, is_training, qkv_layout, bias_shape): """ Test forward with parameterized configs """ @@ -432,13 +434,11 @@ def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, runner.test_forward() @staticmethod - def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, - dropout_prob, dtype, is_training, qkv_layout, bias_shape): + def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob, + dtype, qkv_layout, bias_shape): """ Test backward with parameterized configs """ - if not is_training: - pytest.skip("Backward pass does not support inference.") runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob, dtype, True, qkv_layout, bias_shape) runner.test_backward() diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 175819e115..1b7b4087d0 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -449,6 +449,7 @@ def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): hidden_dropout_dims=(sequence_dim,), intermediate_dropout_dims=(sequence_dim,), layer_type=TransformerLayerType.DECODER, + self_attn_mask_type='padding_causal', dtype=dtype, **te_layer_attrs) ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs, @@ -497,6 +498,7 @@ def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e- hidden_dropout_dims=(sequence_dim,), intermediate_dropout_dims=(sequence_dim,), layer_type=TransformerLayerType.DECODER, + self_attn_mask_type='padding_causal', dtype=dtype, **te_layer_attrs) ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs, diff --git a/tests/jax/test_praxis_layers.py b/tests/jax/test_praxis_layers.py index d9c9e17e97..43581f1015 100644 --- a/tests/jax/test_praxis_layers.py +++ b/tests/jax/test_praxis_layers.py @@ -730,8 +730,13 @@ class TestDotProductAttn(TestLayer): def input_getter(self, shape, dtype): key = jax.random.PRNGKey(seed=1234) q_key, k_key, v_key = jax.random.split(key, 3) - return list(map(partial(jax.random.normal, shape=shape, dtype=dtype), - [q_key, k_key, v_key])) + b, s, *_ = shape + if self.attrs[DotProductAttnAttr.TRANSPOSE_BS]: + b, s = s, b + mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8) + return [ + *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]), mask + ] def get_layer_name(self): return 'dot_product_attn' @@ -765,6 +770,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): @pytest.mark.parametrize('dtype', DTYPE) 
@pytest.mark.parametrize('attrs', DotProductAttnAttr.ATTRS) def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): + self.attrs = attrs praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) @@ -853,9 +859,11 @@ class MultiHeadAttnAttr: class TestMultiHeadAttn(TestLayer): def input_getter(self, shape, dtype): - data_key = jax.random.PRNGKey(seed=1234) - return (jax.random.normal(data_key, shape, - dtype), jax.random.normal(data_key, shape, dtype)) + key = jax.random.PRNGKey(seed=1234) + q_key, kv_key = jax.random.split(key, 2) + s, b, *_ = shape + mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8) + return [*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask] def get_layer_name(self): return 'multi_head_attn' @@ -1183,9 +1191,15 @@ class TransformerLayerAttr: class TestTransformer(TestLayer): def input_getter(self, shape, dtype): - data_key = jax.random.PRNGKey(seed=1234) - return (jax.random.normal(data_key, shape, - dtype), jax.random.normal(data_key, shape, dtype)) + key = jax.random.PRNGKey(seed=1234) + q_key, kv_key = jax.random.split(key, 2) + b, s, *_ = shape + if self.attrs[TransformerLayerAttr.TRANSPOSE_BS]: + b, s = s, b + mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8) + return [ + *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask, mask + ] def get_layer_name(self): return 'transformerlayer' @@ -1277,6 +1291,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): @pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize('attrs', TransformerLayerAttr.ATTRS) def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): + self.attrs = attrs praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) @@ -1292,7 +1307,7 @@ def test_forward_backward_fp8(self, fp8_format, rtol=1e-05, atol=1e-08): - + self.attrs = attrs ds = DelayedScaling(fp8_format=fp8_format) with fp8_autocast(enabled=True, fp8_recipe=ds): praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) diff --git a/transformer_engine/jax/cpp_extensions.py b/transformer_engine/jax/cpp_extensions.py index ba1d44318c..08bcb94239 100644 --- a/transformer_engine/jax/cpp_extensions.py +++ b/transformer_engine/jax/cpp_extensions.py @@ -1368,14 +1368,13 @@ def batcher(batched_args, batch_dims, *, scale_factor): @staticmethod def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): return ScaledSoftmaxFwdPrimitive.forward_infer_sharding_from_operands( - scale_factor, mesh, arg_infos, result_infos - ) + scale_factor, mesh, arg_infos, result_infos) @staticmethod def partition(scale_factor, mesh, arg_infos, result_infos): - return ScaledSoftmaxFwdPrimitive.forward_partition( - ScaledSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos - ) + return ScaledSoftmaxFwdPrimitive.forward_partition(ScaledSoftmaxFwdPrimitive.impl, + scale_factor, mesh, arg_infos, + result_infos) register_primitive(ScaledSoftmaxFwdPrimitive) @@ -1444,14 +1443,13 @@ def batcher(batched_args, batch_dims, *, scale_factor): @staticmethod def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): return ScaledSoftmaxBwdPrimitive.backward_infer_sharding_from_operands( - scale_factor, mesh, arg_infos, result_infos - ) + scale_factor, mesh, arg_infos, result_infos) @staticmethod def 
partition(scale_factor, mesh, arg_infos, result_infos): - return ScaledSoftmaxBwdPrimitive.backward_partition( - ScaledSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos - ) + return ScaledSoftmaxBwdPrimitive.backward_partition(ScaledSoftmaxBwdPrimitive.impl, + scale_factor, mesh, arg_infos, + result_infos) register_primitive(ScaledSoftmaxBwdPrimitive) @@ -1581,14 +1579,12 @@ def batcher(batched_args, batch_dims, *, scale_factor): @staticmethod def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): return ScaledMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands( - scale_factor, mesh, arg_infos,result_infos - ) + scale_factor, mesh, arg_infos, result_infos) @staticmethod def partition(scale_factor, mesh, arg_infos, result_infos): return ScaledMaskedSoftmaxFwdPrimitive.backward_partition( - ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos - ) + ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos) register_primitive(ScaledMaskedSoftmaxFwdPrimitive) @@ -1660,14 +1656,12 @@ def batcher(batched_args, batch_dims, *, scale_factor): @staticmethod def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): return ScaledMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands( - scale_factor, mesh, arg_infos, result_infos - ) + scale_factor, mesh, arg_infos, result_infos) @staticmethod def partition(scale_factor, mesh, arg_infos, result_infos): return ScaledMaskedSoftmaxBwdPrimitive.backward_partition( - ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos - ) + ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos) register_primitive(ScaledMaskedSoftmaxBwdPrimitive) @@ -1749,15 +1743,13 @@ def batcher(batched_args, batch_dims, *, scale_factor): @staticmethod def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands( - scale_factor, mesh, arg_infos, result_infos - ) + scale_factor, mesh, arg_infos, result_infos) @staticmethod def partition(scale_factor, mesh, arg_infos, result_infos): return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_partition( - ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, - arg_infos, result_infos - ) + ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, + result_infos) register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive) @@ -1829,15 +1821,13 @@ def batcher(batched_args, batch_dims, *, scale_factor): @staticmethod def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands( - scale_factor, mesh, arg_infos, result_infos - ) + scale_factor, mesh, arg_infos, result_infos) @staticmethod def partition(scale_factor, mesh, arg_infos, result_infos): return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_partition( - ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, - arg_infos, result_infos - ) + ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, + result_infos) register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive) @@ -1859,16 +1849,16 @@ class FusedAttnHelper: Helper for the fused attention backend """ - q_type: jnp.dtype - kv_type: jnp.dtype + q_dtype: jnp.dtype + kv_dtype: jnp.dtype qkv_layout: NVTE_QKV_Layout attn_bias_type: NVTE_Bias_Type 
attn_mask_type: NVTE_Mask_Type dropout_probability: float - num_heads_q: int - num_heads_kv: int - max_seqlen_q: int - max_seqlen_kv: int + q_num_heads: int + kv_num_heads: int + q_max_seqlen: int + kv_max_seqlen: int head_dim: int def is_fused_attn_kernel_available(self): @@ -1878,11 +1868,38 @@ def is_fused_attn_kernel_available(self): def get_fused_attn_backend(self): """Get the fused attention kernel backend""" return transformer_engine_jax.get_fused_attn_backend( - jax_dtype_to_te_dtype(self.q_type), jax_dtype_to_te_dtype(self.kv_type), + jax_dtype_to_te_dtype(self.q_dtype), jax_dtype_to_te_dtype(self.kv_dtype), self.qkv_layout, self.attn_bias_type, self.attn_mask_type, self.dropout_probability, - self.num_heads_q, self.num_heads_kv, self.max_seqlen_q, self.max_seqlen_kv, + self.q_num_heads, self.kv_num_heads, self.q_max_seqlen, self.kv_max_seqlen, self.head_dim) + @staticmethod + def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout): + """Parse qkv aval""" + match qkv_layout: + case NVTE_QKV_Layout.NVTE_BS3HD: + *q_batch_shape, q_max_seqlen, nqkv, attn_heads, q_head_dim = q_aval.shape + kv_batch_shape = q_batch_shape + kv_max_seqlen = q_max_seqlen + num_gqa_groups = attn_heads + kv_head_dim = q_head_dim + assert nqkv == 3 + case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: + *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape + *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = k_aval.shape + assert nkv == 2 + case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: + *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape + *kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape + assert k_aval.shape == v_aval.shape + case _: + raise ValueError(f"Unexpected {qkv_layout=}") + assert q_batch_shape == kv_batch_shape + assert q_head_dim == kv_head_dim + assert q_aval.dtype == k_aval.dtype == v_aval.dtype + + return (q_batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, q_head_dim) + @dataclass(frozen=True) class _FusedAttnRNGStateChecker: @@ -1933,46 +1950,50 @@ def generate_cu_seqlen(actual_seqlen): return cu_seqlen -class SelfFusedAttnFwdPrimitive(BasePrimitive): +class FusedAttnFwdPrimitive(BasePrimitive): """ - Self Fused Attention Forward Primitive + Fused Attention Forward Primitive """ - name = "te_self_fused_attn_forward" + name = "te_fused_attn_forward" multiple_results = True - impl_static_args = (4, 5, 6, 7, 8) + impl_static_args = (7, 8, 9, 10, 11, 12) inner_primitive = None outer_primitive = None @staticmethod - def abstract(qkv_aval, bias_aval, seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, - attn_mask_type, scaling_factor, dropout_probability, is_training): + def abstract(q_aval, k_aval, v_aval, bias_aval, q_seqlen_or_cu_seqlen_aval, + kv_seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, attn_mask_type, + qkv_layout, scaling_factor, dropout_probability, is_training): """ - Self fused attention fwd inner primitive abstract + Fused attention fwd abstract """ - # outer_primitve is squeezed_mask, inner_primitive is cu_seqlen - del seqlen_or_cu_seqlen_aval - qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype) - *input_batch_shape, max_seqlen, nqkv, attn_heads, head_dim = qkv_aval.shape - assert nqkv == 3 - assert qkv_aval.dtype == bias_aval.dtype + q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) + k_dtype = dtypes.canonicalize_dtype(k_aval.dtype) + v_dtype = dtypes.canonicalize_dtype(v_aval.dtype) + bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) + assert q_dtype == k_dtype == v_dtype == 
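# A minimal sketch of the shape bookkeeping that parse_qkv_aval above performs for the
# three QKV packings. It mirrors the logic on plain shape tuples; the Layout enum is a
# local stand-in for NVTE_QKV_Layout and the example sizes are assumptions.
from enum import Enum, auto

class Layout(Enum):          # stand-in for NVTE_QKV_Layout
    BS3HD = auto()           # qkv packed:  (..., s, 3, h, d)
    BSHD_BS2HD = auto()      # kv packed:   q=(..., sq, h, d), k=(..., skv, 2, g, d)
    BSHD_BSHD_BSHD = auto()  # separate:    q=(..., sq, h, d), k=v=(..., skv, g, d)

def parse_shapes(q_shape, k_shape, layout):
    if layout is Layout.BS3HD:
        *batch, q_s, nqkv, heads, d = q_shape
        assert nqkv == 3
        kv_s, groups = q_s, heads
    elif layout is Layout.BSHD_BS2HD:
        *batch, q_s, heads, d = q_shape
        *_, kv_s, nkv, groups, _ = k_shape
        assert nkv == 2
    else:
        *batch, q_s, heads, d = q_shape
        *_, kv_s, groups, _ = k_shape
    return tuple(batch), q_s, kv_s, heads, groups, d

# e.g. GQA with 16 query heads and 4 KV groups, separate q/k/v tensors (assumed sizes):
print(parse_shapes((8, 512, 16, 64), (8, 512, 4, 64), Layout.BSHD_BSHD_BSHD))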
bias_dtype + assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype + + batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = \ + FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) - output_shape = (*input_batch_shape, max_seqlen, attn_heads, head_dim) - out_aval = qkv_aval.update(shape=output_shape, dtype=qkv_dtype) + output_shape = (*batch_shape, q_max_seqlen, attn_heads, head_dim) + out_aval = q_aval.update(shape=output_shape, dtype=q_dtype) # backend determines the softmax buffer shape/dtype - backend = FusedAttnHelper(qkv_dtype, qkv_dtype, NVTE_QKV_Layout.NVTE_BS3HD, attn_bias_type, - attn_mask_type, dropout_probability, attn_heads, attn_heads, - max_seqlen, max_seqlen, head_dim).get_fused_attn_backend() + backend = FusedAttnHelper(q_dtype, k_dtype, qkv_layout, attn_bias_type, attn_mask_type, + dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, + kv_max_seqlen, head_dim).get_fused_attn_backend() if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: - softmax_shape = (*input_batch_shape, attn_heads, max_seqlen, max_seqlen) - softmax_dtype = qkv_dtype + softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen) + softmax_dtype = q_dtype elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: - softmax_shape = (*input_batch_shape, attn_heads, max_seqlen, 1) + softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1) softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: raise ValueError(f'Unsupported {backend=}') - softmax_aux_aval = qkv_aval.update(shape=softmax_shape, dtype=softmax_dtype) + softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype) # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with # 32-bit unsigned int to get the buffer size we need in the C++ kernel @@ -1990,32 +2011,32 @@ def abstract(qkv_aval, bias_aval, seqlen_or_cu_seqlen_aval, seed_aval, *, attn_b # do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to # prepare for the active fused-attn backend - input_batch = reduce(operator.mul, input_batch_shape) - wkspace_info = transformer_engine_jax.get_self_fused_attn_fwd_workspace_sizes( - input_batch, bias_batch, max_seqlen, attn_heads, bias_heads, head_dim, - scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(qkv_aval.dtype), is_training) - wkspace_aval = qkv_aval.update(shape=wkspace_info[0], - dtype=te_dtype_to_jax_dtype(wkspace_info[1])) + input_batch = reduce(operator.mul, batch_shape) + wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes( + input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, + bias_heads, head_dim, scaling_factor, dropout_probability, attn_bias_type, + attn_mask_type, qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), is_training) + wkspace_aval = q_aval.update(shape=wkspace_info[0], + dtype=te_dtype_to_jax_dtype(wkspace_info[1])) return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval @staticmethod def outer_abstract(*args, **kwargs): """ - Self fused attention fwd outer primitive abstract + Fused attention fwd outer primitive abstract """ out_aval, softmax_aux_aval, rng_state_aval, _ = \ - SelfFusedAttnFwdPrimitive.abstract(*args, **kwargs) + FusedAttnFwdPrimitive.abstract(*args, **kwargs) return out_aval, softmax_aux_aval, rng_state_aval @staticmethod - def lowering(ctx, qkv, bias, cu_seqlen, seed, *, attn_bias_type, attn_mask_type, scaling_factor, - 
dropout_probability, is_training): + def lowering(ctx, q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type, + attn_mask_type, qkv_layout, scaling_factor, dropout_probability, is_training): """ - Self fused attention fwd lowering rules + Fused attention fwd lowering rules """ - operands = [qkv, bias, cu_seqlen, seed] + operands = [q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed] operand_shapes = map(lambda x: x.type.shape, operands) out_types = [ ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) @@ -2023,9 +2044,12 @@ def lowering(ctx, qkv, bias, cu_seqlen, seed, *, attn_bias_type, attn_mask_type, ] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - qkv_aval, bias_aval, *_ = ctx.avals_in - *input_batch_shape, max_seqlen, _, attn_heads, head_dim = qkv_aval.shape - input_batch = reduce(operator.mul, input_batch_shape) + q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in + + batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = \ + FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) + + input_batch = reduce(operator.mul, batch_shape) if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: bias_batch = bias_heads = 0 @@ -2036,137 +2060,137 @@ def lowering(ctx, qkv, bias, cu_seqlen, seed, *, attn_bias_type, attn_mask_type, wkspace_aval = ctx.avals_out[-1] opaque = transformer_engine_jax.pack_fused_attn_descriptor( - input_batch, bias_batch, max_seqlen, max_seqlen, - attn_heads, attn_heads, bias_heads, head_dim, wkspace_aval.size, - scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(qkv_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - is_training) + input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, + bias_heads, head_dim, wkspace_aval.size, scaling_factor, dropout_probability, + attn_bias_type, attn_mask_type, qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), is_training) - out = custom_caller(SelfFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) + out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) return out @staticmethod - def impl(qkv, bias, seqlen, seed, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): - assert SelfFusedAttnFwdPrimitive.inner_primitive is not None + def impl(q, k, v, bias, q_seqlen, kv_seqlen, seed, attn_bias_type, attn_mask_type, qkv_layout, + scaling_factor, dropout_probability, is_training): + assert FusedAttnFwdPrimitive.inner_primitive is not None - cu_seqlen = generate_cu_seqlen(seqlen) + q_cu_seqlen = generate_cu_seqlen(q_seqlen) + kv_cu_seqlen = generate_cu_seqlen(kv_seqlen) - output, softmax_aux, rng_state, _ = SelfFusedAttnFwdPrimitive.inner_primitive.bind( - qkv, + output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind( + q, + k, + v, bias, - cu_seqlen, + q_cu_seqlen, + kv_cu_seqlen, seed, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, scaling_factor=scaling_factor, dropout_probability=dropout_probability, is_training=is_training) return output, softmax_aux, rng_state @staticmethod - def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): + def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, qkv_layout, + scaling_factor, dropout_probability, is_training): _check_valid_batch_dims(batch_dims) - assert 
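# The impl above turns per-example sequence lengths into cumulative offsets
# (q_cu_seqlen / kv_cu_seqlen) before binding the inner primitive, which is the form
# the fused-attention kernels consume for variable-length batches. The helper below is
# a minimal way to build such offsets; it is not the exact generate_cu_seqlen defined
# in this file, only an illustration of the idea.
import jax.numpy as jnp

def cu_seqlen_from_lengths(seqlens):
    """[3, 5, 2] -> [0, 3, 8, 10]: offset of each sequence in the packed batch."""
    return jnp.concatenate([jnp.zeros((1,), dtype=jnp.int32),
                            jnp.cumsum(seqlens.astype(jnp.int32))])

print(cu_seqlen_from_lengths(jnp.array([3, 5, 2])))     # [ 0  3  8 10 ]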
SelfFusedAttnFwdPrimitive.outer_primitive is not None - qkv_bdim, _, _, seed_bdim = batch_dims + assert FusedAttnFwdPrimitive.outer_primitive is not None + q_bdim, *_, seed_bdim = batch_dims - out_bdims = qkv_bdim, qkv_bdim, seed_bdim - return SelfFusedAttnFwdPrimitive.outer_primitive.bind( - *batched_args, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training), out_bdims + out_bdims = q_bdim, q_bdim, seed_bdim + return FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training), out_bdims @staticmethod - def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor, + def infer_sharding_from_operands(attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability, is_training, mesh, arg_infos, result_infos): del attn_bias_type, attn_mask_type, scaling_factor del dropout_probability, is_training, result_infos - x_spec = get_padded_spec(arg_infos[0]) # (...batch, seqlen, 3, head, hidden) - out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-3], *x_spec[-2:])) - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec[:-4], x_spec[-2], x_spec[-4], None)) + q_spec = get_padded_spec(arg_infos[0]) + k_spec = get_padded_spec(arg_infos[1]) + match qkv_layout: + case NVTE_QKV_Layout.NVTE_BS3HD: + # q_spec = (...batch, q_seqlen, head, hidden) + out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:])) + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None)) + case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: + # q_spec = (...batch, q_seqlen, head, hidden) + # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden) + out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-4])) + case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: + # q_spec = (...batch, q_seqlen, head, hidden) + # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden) + out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3])) + case _: + raise ValueError(f"Unsupported {qkv_layout=}") rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None)) return (out_sharding, softmax_aux_sharding, rng_state_sharding) @staticmethod - def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, - mesh, arg_infos, result_infos): - del result_infos - x_spec = get_padded_spec(arg_infos[0]) # (...batch, seqlen, 3, head, hidden) - out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-3], *x_spec[-2:])) - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec[:-4], x_spec[-2], x_spec[-4], None)) - rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None)) - arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [rng_state_sharding]) + def partition(attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability, + is_training, mesh, arg_infos, result_infos): + out_sharding = result_infos[0].sharding + softmax_aux_sharding = result_infos[1].sharding + rng_state_sharding = seed_sharding = 
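# The sharding inference above is pure axis bookkeeping on the query/key partition
# specs: the attention output keeps q's layout, while the softmax auxiliary buffer is
# laid out as (batch..., heads, q_seqlen, kv_seqlen). The sketch below shows that
# reshuffling on plain tuples standing in for PartitionSpec entries; the 'dp'/'tp'
# axis names are assumptions.
q_spec = ('dp', None, 'tp', None)       # (batch, q_seqlen, heads, hidden)
k_spec = ('dp', None, 'tp', None)       # (batch, kv_seqlen, num_gqa_groups, hidden)

out_spec = q_spec                        # output keeps q's layout
softmax_aux_spec = (*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3])
#                   batch...     heads      q_seqlen   kv_seqlen

assert softmax_aux_spec == ('dp', 'tp', None, None)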
NamedSharding(mesh, + PartitionSpec(get_all_mesh_axes(), None)) + arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) - impl = partial(SelfFusedAttnFwdPrimitive.impl, + impl = partial(FusedAttnFwdPrimitive.impl, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, scaling_factor=scaling_factor, dropout_probability=dropout_probability, is_training=is_training) return mesh, impl, out_shardings, arg_shardings -register_primitive(SelfFusedAttnFwdPrimitive) - - -def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray | None, seqlen: jnp.ndarray, - seed: jnp.ndarray | None, attn_bias_type: NVTE_Bias_Type, - attn_mask_type: NVTE_Mask_Type, scaling_factor: float, - dropout_probability: float, is_training: bool): - """ - Wrapper for TE self fused attention fwd - Return BMM1 -> (PreScaleBias) -> Scale -> (PostScaleBias) -> Softmax -> (Dropout) -> BMM2 - """ - checker = _FusedAttnRNGStateChecker() - seed = checker.check_seed(seed, dropout_probability, is_training) - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: - assert bias is None - bias = jnp.zeros(0, dtype=qkv.dtype) - - return SelfFusedAttnFwdPrimitive.outer_primitive.bind(qkv, - bias, - seqlen, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) +register_primitive(FusedAttnFwdPrimitive) -class SelfFusedAttnBwdPrimitive(BasePrimitive): +class FusedAttnBwdPrimitive(BasePrimitive): """ - Self Fused Attention Backward Primitive + Fused Attention Backward Primitive """ - name = "te_self_fused_attn_backward" + name = "te_fused_attn_backward" multiple_results = True - impl_static_args = (7, 8, 9, 10, 11) + impl_static_args = (10, 11, 12, 13, 14, 15) inner_primitive = None outer_primitive = None @staticmethod - def abstract(qkv_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval, doutput_aval, - seqlen_or_cu_seqlen_aval, *, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): + def abstract(q_aval, k_aval, v_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval, + doutput_aval, q_cu_seqlen_aval, kv_cu_seqlen_aval, *, attn_bias_type, + attn_mask_type, qkv_layout, scaling_factor, dropout_probability, is_training): """ - Self fused attention bwd abstract + Fused attention bwd abstract """ - del softmax_aux_aval, rng_state_aval, seqlen_or_cu_seqlen_aval + del softmax_aux_aval, rng_state_aval, output_aval - assert qkv_aval.dtype == bias_aval.dtype == output_aval.dtype == doutput_aval.dtype - *input_batch_shape, max_seqlen, nqkv, attn_heads, head_dim = qkv_aval.shape - assert nqkv == 3 - qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype) + q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) + k_dtype = dtypes.canonicalize_dtype(k_aval.dtype) + v_dtype = dtypes.canonicalize_dtype(v_aval.dtype) bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) + doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype) + assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype + assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype + + batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = \ + FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: bias_batch = bias_heads = 0 @@ -2174,46 +2198,55 @@ def abstract(qkv_aval, bias_aval, 
softmax_aux_aval, rng_state_aval, output_aval, *bias_batch_shape, bias_heads, _, _ = bias_aval.shape bias_batch = reduce(operator.mul, bias_batch_shape) - input_batch = reduce(operator.mul, input_batch_shape) + input_batch = reduce(operator.mul, batch_shape) wkspace_shape, wkspace_dtype = \ - transformer_engine_jax.get_self_fused_attn_bwd_workspace_sizes( - input_batch, bias_batch, max_seqlen, attn_heads, bias_heads, head_dim, - scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(qkv_aval.dtype), is_training - ) + transformer_engine_jax.get_fused_attn_bwd_workspace_sizes( + input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, + bias_heads, head_dim, scaling_factor, dropout_probability, attn_bias_type, + attn_mask_type, qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), is_training) - dqkv_aval = qkv_aval.update(shape=qkv_aval.shape, dtype=qkv_dtype) + dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) + dk_aval = k_aval.update(shape=k_aval.shape, dtype=k_dtype) + dv_aval = v_aval.update(shape=v_aval.shape, dtype=v_dtype) dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) - wkspace_aval = qkv_aval.update(shape=wkspace_shape, - dtype=te_dtype_to_jax_dtype(wkspace_dtype)) + wkspace_aval = q_aval.update(shape=wkspace_shape, + dtype=te_dtype_to_jax_dtype(wkspace_dtype)) - return dqkv_aval, dbias_aval, wkspace_aval + return dq_aval, dk_aval, dv_aval, dbias_aval, wkspace_aval @staticmethod def outer_abstract(*args, **kwargs): """ - Self fused attention bwd outer primitive abstract + Fused attention fwd outer primitive abstract """ - dqkv_aval, dbias_aval, _ = SelfFusedAttnBwdPrimitive.abstract(*args, **kwargs) - return dqkv_aval, dbias_aval + dq_aval, dk_aval, dv_aval, dbias_aval, _ = \ + FusedAttnBwdPrimitive.abstract(*args, **kwargs) + return dq_aval, dk_aval, dv_aval, dbias_aval @staticmethod - def lowering(ctx, qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen, *, - attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training): + def lowering(ctx, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, + kv_cu_seqlen, *, attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, + dropout_probability, is_training): """ - Self fused attention bwd lowering rules + Fused attention bwd lowering rules """ - operands = [qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen] + operands = [ + q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen + ] operand_shapes = map(lambda x: x.type.shape, operands) out_types = [ ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) for output in ctx.avals_out ] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - qkv_aval, bias_aval, *_ = ctx.avals_in - *input_batch_shape, max_seqlen, _, attn_heads, head_dim = qkv_aval.shape - input_batch = reduce(operator.mul, input_batch_shape) + q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in + + batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = \ + FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) + + input_batch = reduce(operator.mul, batch_shape) if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: bias_batch = bias_heads = 0 @@ -2224,780 +2257,245 @@ def lowering(ctx, qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen, wkspace_aval = ctx.avals_out[-1] opaque = transformer_engine_jax.pack_fused_attn_descriptor( - input_batch, bias_batch, max_seqlen, 
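# Both abstract() and lowering() above flatten any leading batch dimensions into a
# single count (input_batch, and bias_batch likewise) before packing the descriptor,
# since the kernel only needs the total number of sequences. Minimal illustration;
# the example batch shape is made up.
import operator
from functools import reduce

batch_shape = (4, 2)                     # e.g. (outer batch, micro-batch), assumed
input_batch = reduce(operator.mul, batch_shape)
assert input_batch == 8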
max_seqlen, - attn_heads, attn_heads, bias_heads, head_dim, wkspace_aval.size, - scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(qkv_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - is_training) + input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, + bias_heads, head_dim, wkspace_aval.size, scaling_factor, dropout_probability, + attn_bias_type, attn_mask_type, qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), is_training) - out = custom_caller(SelfFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) + out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) return out @staticmethod - def impl(qkv, bias, softmax_aux, rng_state, output, doutput, seqlen, attn_bias_type, - attn_mask_type, scaling_factor, dropout_probability, is_training): - assert SelfFusedAttnBwdPrimitive.inner_primitive is not None + def impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen, + attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability, + is_training): + assert FusedAttnBwdPrimitive.inner_primitive is not None - cu_seqlen = generate_cu_seqlen(seqlen) + q_cu_seqlen = generate_cu_seqlen(q_seqlen) + kv_cu_seqlen = generate_cu_seqlen(kv_seqlen) - dqkv, dbias, _ = SelfFusedAttnBwdPrimitive.inner_primitive.bind( - qkv, + dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind( + q, + k, + v, bias, softmax_aux, rng_state, output, doutput, - cu_seqlen, + q_cu_seqlen, + kv_cu_seqlen, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, scaling_factor=scaling_factor, dropout_probability=dropout_probability, is_training=is_training) - return dqkv, dbias + return dq, dk, dv, dbias @staticmethod - def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): + def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, qkv_layout, + scaling_factor, dropout_probability, is_training): _check_valid_batch_dims(batch_dims) - assert SelfFusedAttnBwdPrimitive.outer_primitive is not None - qkv_bdim, *_ = batch_dims + assert FusedAttnBwdPrimitive.outer_primitive is not None + q_bdim, k_bdim, v_bdim, *_ = batch_dims - out_bdims = qkv_bdim, qkv_bdim - return SelfFusedAttnBwdPrimitive.outer_primitive.bind( - *batched_args, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training), out_bdims + out_bdims = q_bdim, k_bdim, v_bdim, q_bdim + return FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training), out_bdims @staticmethod - def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor, + def infer_sharding_from_operands(attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability, is_training, mesh, arg_infos, result_infos): - del attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, - del is_training, result_infos - x_spec = get_padded_spec(arg_infos[0]) - bias_spec = get_padded_spec(arg_infos[1]) - dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + del attn_bias_type, attn_mask_type, qkv_layout, scaling_factor + del dropout_probability, 
is_training, result_infos + q_spec = get_padded_spec(arg_infos[0]) + k_spec = get_padded_spec(arg_infos[1]) + v_spec = get_padded_spec(arg_infos[2]) + bias_spec = get_padded_spec(arg_infos[3]) + dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) + dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) - return (dx_sharding, dbias_sharding) + return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) @staticmethod - def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, - mesh, arg_infos, result_infos): + def partition(attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability, + is_training, mesh, arg_infos, result_infos): del result_infos - x_spec = get_padded_spec(arg_infos[0]) - bias_spec = get_padded_spec(arg_infos[1]) - dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + q_spec = get_padded_spec(arg_infos[0]) + k_spec = get_padded_spec(arg_infos[1]) + v_spec = get_padded_spec(arg_infos[2]) + bias_spec = get_padded_spec(arg_infos[3]) + dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) + dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - out_shardings = (dx_sharding, dbias_sharding) + out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) - def sharded_impl(qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen): - local_dx, local_dbias = SelfFusedAttnBwdPrimitive.impl( - qkv, + def sharded_impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, + kv_cu_seqlen): + local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl( + q, + k, + v, bias, softmax_aux, rng_state, output, doutput, - cu_seqlen, + q_cu_seqlen, + kv_cu_seqlen, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, scaling_factor=scaling_factor, dropout_probability=dropout_probability, is_training=is_training) global_dbias = local_dbias if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) - return local_dx, global_dbias + return local_dq, local_dk, local_dv, global_dbias return mesh, sharded_impl, out_shardings, arg_shardings -register_primitive(SelfFusedAttnBwdPrimitive) +register_primitive(FusedAttnBwdPrimitive) -def self_fused_attn_bwd(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.ndarray, - rng_state: jnp.ndarray, output: jnp.ndarray, doutput: jnp.ndarray, - seqlen: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, - attn_mask_type: NVTE_Mask_Type, scaling_factor: float, - dropout_probability: float, is_training: bool): +def fused_attn_fwd_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray, seqlen: jnp.ndarray, + seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, scaling_factor: float, + dropout_probability: float, is_training: bool): """ - Wrapper for TE self fused attention bwd - Return the gradients of self fused attention with packed qkv input + Wrapper for TE self fused attention fwd + Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 """ + checker = _FusedAttnRNGStateChecker() + seed = checker.check_seed(seed, dropout_probability, is_training) + if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: assert bias 
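# In sharded_impl above, dq/dk/dv stay local (they are sharded like their inputs), but
# dbias is all-reduced across the data-parallel/FSDP axes: the bias is replicated while
# the batch is sharded, so each replica only sees the gradient contribution of its own
# batch slice. The toy example below simulates that reduction for two "replicas"
# without any mesh machinery; shapes and values are made up.
import jax.numpy as jnp

local_dbias_replica0 = jnp.full((1, 4, 8, 8), 0.25)   # gradient from replica 0's batch shard
local_dbias_replica1 = jnp.full((1, 4, 8, 8), 0.75)   # gradient from replica 1's batch shard

# conceptually what the all-reduce produces: the summed gradient, identical on every replica
global_dbias = local_dbias_replica0 + local_dbias_replica1
assert jnp.allclose(global_dbias, 1.0)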
is None bias = jnp.zeros(0, dtype=qkv.dtype) - return SelfFusedAttnBwdPrimitive.outer_primitive.bind(qkv, - bias, - softmax_aux, - rng_state, - output, - doutput, - seqlen, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + _not_used = jnp.zeros(0, qkv.dtype) + return FusedAttnFwdPrimitive.outer_primitive.bind(qkv, + _not_used, + _not_used, + bias, + seqlen, + seqlen, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=NVTE_QKV_Layout.NVTE_BS3HD, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) -class CrossFusedAttnFwdPrimitive(BasePrimitive): +def fused_attn_bwd_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.ndarray, + rng_state: jnp.ndarray, output: jnp.ndarray, doutput: jnp.ndarray, + seqlen: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, scaling_factor: float, + dropout_probability: float, is_training: bool): + """ + Wrapper for TE self fused attention bwd + Return the gradients of self fused attention with packed qkv input """ - Cross Fused Attention Forward Primitive + if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + assert bias is None + bias = jnp.zeros(0, dtype=qkv.dtype) + dummy_input = jnp.zeros(0, dtype=qkv.dtype) + dqkv, *_, dbias = FusedAttnBwdPrimitive.outer_primitive.bind( + qkv, + dummy_input, + dummy_input, + bias, + softmax_aux, + rng_state, + output, + doutput, + seqlen, + seqlen, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=NVTE_QKV_Layout.NVTE_BS3HD, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) + return dqkv, dbias + + +def fused_attn_fwd_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, + q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray, seed: jnp.ndarray, + attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, + scaling_factor: float, dropout_probability: float, is_training: bool): + """ + Wrapper for TE fused attention fwd with kvpacked inputs + Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 """ - name = "te_cross_fused_attn_forward" - multiple_results = True - impl_static_args = (6, 7, 8, 9, 10) - inner_primitive = None - outer_primitive = None + checker = _FusedAttnRNGStateChecker() + seed = checker.check_seed(seed, dropout_probability, is_training) - @staticmethod - def abstract(q_aval, kv_aval, bias_aval, q_seqlen_or_cu_seqlen_aval, - kv_seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, attn_mask_type, - scaling_factor, dropout_probability, is_training): - """ - Cross fused attention fwd abstract - """ - q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) - kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype) - bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) - assert q_dtype == kv_dtype == bias_dtype - assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype + if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + assert bias is None + bias = jnp.zeros(0, dtype=q.dtype) - *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape - *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = kv_aval.shape - assert q_batch_shape == kv_batch_shape - assert q_head_dim == kv_head_dim - assert nkv == 2 - out_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) + return FusedAttnFwdPrimitive.outer_primitive.bind(q, + kv, + 
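# fused_attn_fwd_qkvpacked above keeps the packed-QKV entry point by passing the packed
# tensor as q and zero-sized placeholders as k/v, with qkv_layout=NVTE_BS3HD telling the
# kernel how to interpret it. The sketch below only illustrates what that BS3HD packing
# looks like relative to separate q/k/v tensors; the sizes are assumptions.
import jax
import jax.numpy as jnp

b, s, h, d = 2, 16, 4, 32
qkv = jax.random.normal(jax.random.PRNGKey(0), (b, s, 3, h, d))   # BS3HD packing

q, k, v = (qkv[:, :, i] for i in range(3))        # the three interleaved tensors
assert q.shape == k.shape == v.shape == (b, s, h, d)

not_used = jnp.zeros(0, qkv.dtype)                # zero-sized placeholder, as in the wrapper
assert not_used.size == 0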
jnp.zeros(0, q.dtype), + bias, + q_seqlen, + kv_seqlen, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) - # backend determines the softmax buffer shape/dtype - backend = FusedAttnHelper(q_dtype, kv_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD, - attn_bias_type, attn_mask_type, dropout_probability, attn_heads, - num_gqa_groups, q_max_seqlen, kv_max_seqlen, - q_head_dim).get_fused_attn_backend() - if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: - softmax_shape = (*q_batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen) - softmax_dtype = q_dtype - elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: - softmax_shape = (*q_batch_shape, attn_heads, q_max_seqlen, 1) - softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) - else: - raise ValueError(f'Unsupported {backend=}') - softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype) - - # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with - # 32-bit unsigned int to get the buffer size we need in the C++ kernel - checker = _FusedAttnRNGStateChecker() - seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype) - assert seed_dtype == checker.rng_state_dtype - rng_state_shape = (seed_aval.shape[0], checker.rng_state_size) - rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype) - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: - bias_batch = bias_heads = 0 - else: - *bias_batch_shape, bias_heads, _, _ = bias_aval.shape - bias_batch = reduce(operator.mul, bias_batch_shape) - - # do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to - # prepare for the active fused-attn backend - input_batch = reduce(operator.mul, q_batch_shape) - wkspace_info = transformer_engine_jax.get_cross_fused_attn_fwd_workspace_sizes( - input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, - attn_heads, num_gqa_groups, bias_heads, q_head_dim, - scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(q_aval.dtype), is_training) - wkspace_aval = q_aval.update(shape=wkspace_info[0], - dtype=te_dtype_to_jax_dtype(wkspace_info[1])) - - return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval - - @staticmethod - def outer_abstract(*args, **kwargs): - """ - Cross fused attention fwd outer primitive abstract - """ - out_aval, softmax_aux_aval, rng_state_aval, _ = \ - CrossFusedAttnFwdPrimitive.abstract(*args, **kwargs) - return out_aval, softmax_aux_aval, rng_state_aval - - @staticmethod - def lowering(ctx, q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type, - attn_mask_type, scaling_factor, dropout_probability, is_training): - """ - Cross fused attention fwd lowering rules - """ - operands = [q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed] - operand_shapes = map(lambda x: x.type.shape, operands) - out_types = [ - ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) - for output in ctx.avals_out - ] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - q_aval, kv_aval, bias_aval, *_ = ctx.avals_in - *input_batch_shape, q_max_seqlen, attn_heads, head_dim = q_aval.shape - *_, kv_max_seqlen, _, num_gqa_groups, _ = kv_aval.shape - input_batch = reduce(operator.mul, input_batch_shape) - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: - bias_batch = bias_heads = 0 - else: - 
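# fused_attn_fwd_kvpacked above does the analogous thing for cross attention: q stays
# separate while k and v travel packed as a (..., kv_seqlen, 2, num_gqa_groups, head_dim)
# tensor, selected by qkv_layout=NVTE_BSHD_BS2HD. A small shape illustration with
# assumed sizes:
import jax
import jax.numpy as jnp

b, sq, skv, h, g, d = 2, 16, 24, 8, 2, 32          # 8 query heads, 2 GQA KV groups
q = jax.random.normal(jax.random.PRNGKey(0), (b, sq, h, d))
kv = jax.random.normal(jax.random.PRNGKey(1), (b, skv, 2, g, d))   # BS2HD packing

k, v = kv[:, :, 0], kv[:, :, 1]
assert k.shape == v.shape == (b, skv, g, d)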
*bias_batch_shape, bias_heads, _, _ = bias_aval.shape - bias_batch = reduce(operator.mul, bias_batch_shape) - - wkspace_aval = ctx.avals_out[-1] - - opaque = transformer_engine_jax.pack_fused_attn_descriptor( - input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, - attn_heads, num_gqa_groups, bias_heads, head_dim, - wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - is_training) - - out = custom_caller(CrossFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) - - return out - - @staticmethod - def impl(q, kv, bias, q_seqlen, kv_seqlen, seed, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): - assert CrossFusedAttnFwdPrimitive.inner_primitive is not None - - q_cu_seqlen = generate_cu_seqlen(q_seqlen) - kv_cu_seqlen = generate_cu_seqlen(kv_seqlen) - - output, softmax_aux, rng_state, _ = CrossFusedAttnFwdPrimitive.inner_primitive.bind( - q, - kv, - bias, - q_cu_seqlen, - kv_cu_seqlen, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - return output, softmax_aux, rng_state - - @staticmethod - def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): - _check_valid_batch_dims(batch_dims) - assert CrossFusedAttnFwdPrimitive.outer_primitive is not None - q_bdim, *_, seed_bdim = batch_dims - - out_bdims = q_bdim, q_bdim, seed_bdim - return CrossFusedAttnFwdPrimitive.outer_primitive.bind( - *batched_args, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training), out_bdims - - @staticmethod - def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training, mesh, arg_infos, - result_infos): - del attn_bias_type, attn_mask_type, scaling_factor - del dropout_probability, is_training, result_infos - q_spec = get_padded_spec(arg_infos[0]) # (...batch, q_seqlen, head, hidden) - kv_spec = get_padded_spec(arg_infos[1]) # (...batch, kv_seqlen, 2, head, hidden) - out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], kv_spec[-4])) - rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None)) - return (out_sharding, softmax_aux_sharding, rng_state_sharding) - - @staticmethod - def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, - mesh, arg_infos, result_infos): - del result_infos - q_spec = get_padded_spec(arg_infos[0]) # (...batch, q_seqlen, head, hidden) - kv_spec = get_padded_spec(arg_infos[1]) # (...batch, kv_seqlen, 2, head, hidden) - out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], kv_spec[-4])) - rng_state_sharding = seed_sharding = NamedSharding(mesh, - PartitionSpec(get_all_mesh_axes(), None)) - arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) - out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) - impl = partial(CrossFusedAttnFwdPrimitive.impl, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - 
dropout_probability=dropout_probability, - is_training=is_training) - return mesh, impl, out_shardings, arg_shardings - - -register_primitive(CrossFusedAttnFwdPrimitive) - - -def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, q_seqlen: jnp.ndarray, - kv_seqlen: jnp.ndarray, seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, - attn_mask_type: NVTE_Mask_Type, scaling_factor: float, - dropout_probability: float, is_training: bool): - """ - Wrapper for TE cross fused attention fwd - Return BMM1 -> (PreScaleBias) -> Scale -> (PostScaleBias) -> Softmax -> (Dropout) -> BMM2 - """ - checker = _FusedAttnRNGStateChecker() - seed = checker.check_seed(seed, dropout_probability, is_training) - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: - assert bias is None - bias = jnp.zeros(0, dtype=q.dtype) - - return CrossFusedAttnFwdPrimitive.outer_primitive.bind(q, - kv, - bias, - q_seqlen, - kv_seqlen, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - - -class CrossFusedAttnBwdPrimitive(BasePrimitive): - """ - Cross Fused Attention Backward Primitive - """ - name = "te_cross_fused_attn_backward" - multiple_results = True - impl_static_args = (9, 10, 11, 12, 13) - inner_primitive = None - outer_primitive = None - - @staticmethod - def abstract(q_aval, kv_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval, - doutput_aval, q_cu_seqlen_aval, kv_cu_seqlen_aval, *, attn_bias_type, - attn_mask_type, scaling_factor, dropout_probability, is_training): - """ - Cross fused attention bwd abstract - """ - del softmax_aux_aval, rng_state_aval, output_aval - - q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) - kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype) - bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) - doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype) - assert q_dtype == kv_dtype == bias_dtype == doutput_dtype - assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype - - *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape - *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = kv_aval.shape - assert q_batch_shape == kv_batch_shape - assert q_head_dim == kv_head_dim - assert nkv == 2 - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: - bias_batch = bias_heads = 0 - else: - *bias_batch_shape, bias_heads, _, _ = bias_aval.shape - bias_batch = reduce(operator.mul, bias_batch_shape) - - input_batch = reduce(operator.mul, q_batch_shape) - wkspace_shape, wkspace_dtype = \ - transformer_engine_jax.get_cross_fused_attn_bwd_workspace_sizes( - input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, - attn_heads, num_gqa_groups, bias_heads, q_head_dim, - scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(q_aval.dtype), is_training - ) - - dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) - dkv_aval = kv_aval.update(shape=kv_aval.shape, dtype=kv_dtype) - dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) - wkspace_aval = q_aval.update(shape=wkspace_shape, - dtype=te_dtype_to_jax_dtype(wkspace_dtype)) - - return dq_aval, dkv_aval, dbias_aval, wkspace_aval - - @staticmethod - def outer_abstract(*args, **kwargs): - """ - Cross fused attention fwd outer primitive abstract - """ - dq_aval, dkv_aval, dbias_aval, _ = \ - CrossFusedAttnBwdPrimitive.abstract(*args, **kwargs) - return dq_aval, dkv_aval, dbias_aval - - @staticmethod - def 
lowering(ctx, q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, - kv_cu_seqlen, *, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): - """ - Cross fused attention bwd lowering rules - """ - operands = [q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen] - operand_shapes = map(lambda x: x.type.shape, operands) - out_types = [ - ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) - for output in ctx.avals_out - ] - - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - q_aval, kv_aval, bias_aval, *_ = ctx.avals_in - *input_batch_shape, q_max_seqlen, attn_heads, head_dim = q_aval.shape - *_, kv_max_seqlen, _, num_gqa_groups, _ = kv_aval.shape - input_batch = reduce(operator.mul, input_batch_shape) - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: - bias_batch = bias_heads = 0 - else: - *bias_batch_shape, bias_heads, _, _ = bias_aval.shape - bias_batch = reduce(operator.mul, bias_batch_shape) - - wkspace_aval = ctx.avals_out[-1] - - opaque = transformer_engine_jax.pack_fused_attn_descriptor( - input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, - attn_heads, num_gqa_groups, bias_heads, head_dim, - wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - is_training) - - out = custom_caller(CrossFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) - - return out - - @staticmethod - def impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen, - attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training): - assert CrossFusedAttnBwdPrimitive.inner_primitive is not None - - q_cu_seqlen = generate_cu_seqlen(q_seqlen) - kv_cu_seqlen = generate_cu_seqlen(kv_seqlen) - - dq, dkv, dbias, _ = CrossFusedAttnBwdPrimitive.inner_primitive.bind( - q, - kv, - bias, - softmax_aux, - rng_state, - output, - doutput, - q_cu_seqlen, - kv_cu_seqlen, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - return dq, dkv, dbias - - @staticmethod - def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): - _check_valid_batch_dims(batch_dims) - assert CrossFusedAttnBwdPrimitive.outer_primitive is not None - q_bdim, kv_bdim, *_ = batch_dims - - out_bdims = q_bdim, kv_bdim, q_bdim - return CrossFusedAttnBwdPrimitive.outer_primitive.bind( - *batched_args, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training), out_bdims - - @staticmethod - def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training, mesh, arg_infos, - result_infos): - del attn_bias_type, attn_mask_type, scaling_factor - del dropout_probability, is_training, result_infos - q_spec = get_padded_spec(arg_infos[0]) - kv_spec = get_padded_spec(arg_infos[1]) - bias_spec = get_padded_spec(arg_infos[2]) - dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - dkv_sharding = NamedSharding(mesh, PartitionSpec(*kv_spec)) - dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) - return (dq_sharding, dkv_sharding, dbias_sharding) - - @staticmethod - def partition(attn_bias_type, attn_mask_type, scaling_factor, 
dropout_probability, is_training, - mesh, arg_infos, result_infos): - del result_infos - q_spec = get_padded_spec(arg_infos[0]) - kv_spec = get_padded_spec(arg_infos[1]) - bias_spec = get_padded_spec(arg_infos[2]) - dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - dkv_sharding = NamedSharding(mesh, PartitionSpec(*kv_spec)) - dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - out_shardings = (dq_sharding, dkv_sharding, dbias_sharding) - - def sharded_impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, - kv_cu_seqlen): - local_dq, local_dkv, local_dbias = CrossFusedAttnBwdPrimitive.impl( - q, - kv, - bias, - softmax_aux, - rng_state, - output, - doutput, - q_cu_seqlen, - kv_cu_seqlen, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - global_dbias = local_dbias - if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: - global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) - return local_dq, local_dkv, global_dbias - - return mesh, sharded_impl, out_shardings, arg_shardings - - -register_primitive(CrossFusedAttnBwdPrimitive) - - -def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, - softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray, - doutput: jnp.ndarray, q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray, - attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, - scaling_factor: float, dropout_probability: float, is_training: bool): - """ - Wrapper for TE cross fused attention bwd - Return the gradients of cross fused attention with packed kv input - """ - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: - assert bias is None - bias = jnp.zeros(0, dtype=q.dtype) - - return CrossFusedAttnBwdPrimitive.outer_primitive.bind(q, - kv, - bias, - softmax_aux, - rng_state, - output, - doutput, - q_seqlen, - kv_seqlen, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - - -class FusedAttnFwdPrimitive(BasePrimitive): - """ - Fused Attention Forward Primitive - Query, key, value are seperated tensors - """ - name = "te_fused_attn_forward" - multiple_results = True - impl_static_args = (7, 8, 9, 10, 11) - inner_primitive = None - outer_primitive = None - - @staticmethod - def abstract(q_aval, k_aval, v_aval, bias_aval, q_seqlen_or_cu_seqlen_aval, - kv_seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, attn_mask_type, - scaling_factor, dropout_probability, is_training): - """ - Fused attention fwd abstract - """ - q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) - k_dtype = dtypes.canonicalize_dtype(k_aval.dtype) - v_dtype = dtypes.canonicalize_dtype(v_aval.dtype) - bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) - assert q_dtype == k_dtype == v_dtype == bias_dtype - assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype - - *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape - *kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape - assert q_batch_shape == kv_batch_shape - assert q_head_dim == kv_head_dim - assert k_aval.shape == v_aval.shape - out_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) - - # backend determines the softmax buffer shape/dtype - backend = FusedAttnHelper(q_dtype, k_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD, 
attn_bias_type, - attn_mask_type, dropout_probability, attn_heads, num_gqa_groups, - q_max_seqlen, kv_max_seqlen, q_head_dim).get_fused_attn_backend() - - if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: - softmax_shape = (*q_batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen) - softmax_dtype = q_dtype - elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: - softmax_shape = (*q_batch_shape, attn_heads, q_max_seqlen, 1) - softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) - else: - raise ValueError(f'Unsupported {backend=}') - softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype) - - # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with - # 32-bit unsigned int to get the buffer size we need in the C++ kernel - checker = _FusedAttnRNGStateChecker() - seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype) - assert seed_dtype == checker.rng_state_dtype - rng_state_shape = (seed_aval.shape[0], checker.rng_state_size) - rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype) - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: - bias_batch = bias_heads = 0 - else: - *bias_batch_shape, bias_heads, _, _ = bias_aval.shape - bias_batch = reduce(operator.mul, bias_batch_shape) - - # do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to - # prepare for the active fused-attn backend - input_batch = reduce(operator.mul, q_batch_shape) - wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes( - input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, - attn_heads, num_gqa_groups, bias_heads, q_head_dim, - scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(q_aval.dtype), is_training) - wkspace_aval = q_aval.update(shape=wkspace_info[0], - dtype=te_dtype_to_jax_dtype(wkspace_info[1])) - - return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval - - @staticmethod - def outer_abstract(*args, **kwargs): - """ - Fused attention fwd outer primitive abstract - """ - out_aval, softmax_aux_aval, rng_state_aval, _ = \ - FusedAttnFwdPrimitive.abstract(*args, **kwargs) - return out_aval, softmax_aux_aval, rng_state_aval - - @staticmethod - def lowering(ctx, q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type, - attn_mask_type, scaling_factor, dropout_probability, is_training): - """ - Fused attention fwd lowering rules - """ - operands = [q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed] - operand_shapes = map(lambda x: x.type.shape, operands) - out_types = [ - ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) - for output in ctx.avals_out - ] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in - *batch_shape, q_max_seqlen, attn_heads, head_dim = q_aval.shape - *_, kv_max_seqlen, num_gqa_groups, _ = k_aval.shape - assert k_aval.shape == v_aval.shape - input_batch = reduce(operator.mul, batch_shape) - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: - bias_batch = bias_heads = 0 - else: - *bias_batch_shape, bias_heads, _, _ = bias_aval.shape - bias_batch = reduce(operator.mul, bias_batch_shape) - - wkspace_aval = ctx.avals_out[-1] - - opaque = transformer_engine_jax.pack_fused_attn_descriptor( - input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, - attn_heads, num_gqa_groups, bias_heads, head_dim, - wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - 
jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - is_training) - - out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) - - return out - - @staticmethod - def impl(q, k, v, bias, q_seqlen, kv_seqlen, seed, attn_bias_type, attn_mask_type, - scaling_factor, dropout_probability, is_training): - assert FusedAttnFwdPrimitive.inner_primitive is not None - - q_cu_seqlen = generate_cu_seqlen(q_seqlen) - kv_cu_seqlen = generate_cu_seqlen(kv_seqlen) - - output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind( - q, - k, - v, - bias, - q_cu_seqlen, - kv_cu_seqlen, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - return output, softmax_aux, rng_state - - @staticmethod - def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): - _check_valid_batch_dims(batch_dims) - assert FusedAttnFwdPrimitive.outer_primitive is not None - q_bdim, *_, seed_bdim = batch_dims - - out_bdims = q_bdim, q_bdim, seed_bdim - return FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training), out_bdims - - @staticmethod - def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training, mesh, arg_infos, - result_infos): - del attn_bias_type, attn_mask_type, scaling_factor - del dropout_probability, is_training, result_infos - q_spec = get_padded_spec(arg_infos[0]) # (...batch, q_seqlen, head, hidden) - k_spec = get_padded_spec(arg_infos[1]) # (...batch, kv_seqlen, head, hidden) - out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3])) - rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None)) - return (out_sharding, softmax_aux_sharding, rng_state_sharding) - - @staticmethod - def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, - mesh, arg_infos, result_infos): - del result_infos - q_spec = get_padded_spec(arg_infos[0]) # (...batch, q_seqlen, head, hidden) - k_spec = get_padded_spec(arg_infos[1]) # (...batch, kv_seqlen, head, hidden) - out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3])) - rng_state_sharding = seed_sharding = NamedSharding(mesh, - PartitionSpec(get_all_mesh_axes(), None)) - arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) - out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) - impl = partial(FusedAttnFwdPrimitive.impl, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - return mesh, impl, out_shardings, arg_shardings - - -register_primitive(FusedAttnFwdPrimitive) +def fused_attn_bwd_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, + softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray, + doutput: jnp.ndarray, q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray, + attn_bias_type: NVTE_Bias_Type, attn_mask_type: 
NVTE_Mask_Type, + scaling_factor: float, dropout_probability: float, is_training: bool): + """ + Wrapper for TE fused attention bwd with kvpacked inputs + Return the gradients of fused attention with packed kv input + """ + if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + assert bias is None + bias = jnp.zeros(0, dtype=q.dtype) + dummy_input = jnp.zeros(0, q.dtype) + dq, dkv, _, dbias = FusedAttnBwdPrimitive.outer_primitive.bind( + q, + kv, + dummy_input, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) + return dq, dkv, dbias def fused_attn_fwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, @@ -3006,7 +2504,7 @@ def fused_attn_fwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.nda scaling_factor: float, dropout_probability: float, is_training: bool): """ Wrapper for TE fused attention fwd, where query, key, value are seperated tensors - Return BMM1 -> (PreScaleBias) -> Scale -> (PostScaleBias) -> Softmax -> (Dropout) -> BMM2 + Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 """ checker = _FusedAttnRNGStateChecker() seed = checker.check_seed(seed, dropout_probability, is_training) @@ -3015,228 +2513,20 @@ def fused_attn_fwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.nda assert bias is None bias = jnp.zeros(0, dtype=q.dtype) - return FusedAttnFwdPrimitive.outer_primitive.bind(q, - k, - v, - bias, - q_seqlen, - kv_seqlen, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - - -class FusedAttnBwdPrimitive(BasePrimitive): - """ - Fused Attention Backward Primitive - """ - name = "te_fused_attn_backward" - multiple_results = True - impl_static_args = (10, 11, 12, 13, 14) - inner_primitive = None - outer_primitive = None - - @staticmethod - def abstract(q_aval, k_aval, v_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval, - doutput_aval, q_cu_seqlen_aval, kv_cu_seqlen_aval, *, attn_bias_type, - attn_mask_type, scaling_factor, dropout_probability, is_training): - """ - Fused attention bwd abstract - """ - del softmax_aux_aval, rng_state_aval, output_aval - - q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) - k_dtype = dtypes.canonicalize_dtype(k_aval.dtype) - v_dtype = dtypes.canonicalize_dtype(v_aval.dtype) - bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) - doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype) - assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype - assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype - - *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape - *kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape - assert q_batch_shape == kv_batch_shape - assert q_head_dim == kv_head_dim - assert k_aval.shape == v_aval.shape - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: - bias_batch = bias_heads = 0 - else: - *bias_batch_shape, bias_heads, _, _ = bias_aval.shape - bias_batch = reduce(operator.mul, bias_batch_shape) - - input_batch = reduce(operator.mul, q_batch_shape) - wkspace_shape, wkspace_dtype = \ - transformer_engine_jax.get_fused_attn_bwd_workspace_sizes( - input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, - attn_heads, 
num_gqa_groups, bias_heads, q_head_dim, - scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(q_aval.dtype), is_training - ) - - dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) - dk_aval = k_aval.update(shape=k_aval.shape, dtype=k_dtype) - dv_aval = v_aval.update(shape=v_aval.shape, dtype=v_dtype) - dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) - wkspace_aval = q_aval.update(shape=wkspace_shape, - dtype=te_dtype_to_jax_dtype(wkspace_dtype)) - - return dq_aval, dk_aval, dv_aval, dbias_aval, wkspace_aval - - @staticmethod - def outer_abstract(*args, **kwargs): - """ - Fused attention fwd outer primitive abstract - """ - dq_aval, dk_aval, dv_aval, dbias_aval, _ = \ - FusedAttnBwdPrimitive.abstract(*args, **kwargs) - return dq_aval, dk_aval, dv_aval, dbias_aval - - @staticmethod - def lowering(ctx, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, - kv_cu_seqlen, *, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): - """ - Fused attention bwd lowering rules - """ - operands = [ - q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen - ] - operand_shapes = map(lambda x: x.type.shape, operands) - out_types = [ - ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) - for output in ctx.avals_out - ] - - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in - *batch_shape, q_max_seqlen, attn_heads, head_dim = q_aval.shape - *_, kv_max_seqlen, num_gqa_groups, _ = k_aval.shape - assert k_aval.shape == v_aval.shape - input_batch = reduce(operator.mul, batch_shape) - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: - bias_batch = bias_heads = 0 - else: - *bias_batch_shape, bias_heads, _, _ = bias_aval.shape - bias_batch = reduce(operator.mul, bias_batch_shape) - - wkspace_aval = ctx.avals_out[-1] - - opaque = transformer_engine_jax.pack_fused_attn_descriptor( - input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, - attn_heads, num_gqa_groups, bias_heads, head_dim, - wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - is_training) - - out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) - - return out - - @staticmethod - def impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen, - attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training): - assert FusedAttnBwdPrimitive.inner_primitive is not None - - q_cu_seqlen = generate_cu_seqlen(q_seqlen) - kv_cu_seqlen = generate_cu_seqlen(kv_seqlen) - - dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind( - q, - k, - v, - bias, - softmax_aux, - rng_state, - output, - doutput, - q_cu_seqlen, - kv_cu_seqlen, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - return dq, dk, dv, dbias - - @staticmethod - def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): - _check_valid_batch_dims(batch_dims) - assert FusedAttnBwdPrimitive.outer_primitive is not None - q_bdim, k_bdim, v_bdim, *_ = batch_dims - - out_bdims = q_bdim, k_bdim, v_bdim, q_bdim - return FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, - 
attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training), out_bdims - - @staticmethod - def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training, mesh, arg_infos, - result_infos): - del attn_bias_type, attn_mask_type, scaling_factor - del dropout_probability, is_training, result_infos - q_spec = get_padded_spec(arg_infos[0]) - k_spec = get_padded_spec(arg_infos[1]) - v_spec = get_padded_spec(arg_infos[2]) - bias_spec = get_padded_spec(arg_infos[3]) - dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) - dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) - dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) - return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) - - @staticmethod - def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, - mesh, arg_infos, result_infos): - del result_infos - q_spec = get_padded_spec(arg_infos[0]) - k_spec = get_padded_spec(arg_infos[1]) - v_spec = get_padded_spec(arg_infos[2]) - bias_spec = get_padded_spec(arg_infos[3]) - dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) - dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) - dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) - - def sharded_impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, - kv_cu_seqlen): - local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl( - q, - k, - v, - bias, - softmax_aux, - rng_state, - output, - doutput, - q_cu_seqlen, - kv_cu_seqlen, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - global_dbias = local_dbias - if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: - global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) - return local_dq, local_dk, local_dv, global_dbias - - return mesh, sharded_impl, out_shardings, arg_shardings - - -register_primitive(FusedAttnBwdPrimitive) + return FusedAttnFwdPrimitive.outer_primitive.bind( + q, + k, + v, + bias, + q_seqlen, + kv_seqlen, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) def fused_attn_bwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, @@ -3251,22 +2541,23 @@ def fused_attn_bwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.nda if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: assert bias is None bias = jnp.zeros(0, dtype=q.dtype) - - return FusedAttnBwdPrimitive.outer_primitive.bind(q, - k, - v, - bias, - softmax_aux, - rng_state, - output, - doutput, - q_seqlen, - kv_seqlen, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + return FusedAttnBwdPrimitive.outer_primitive.bind( + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + attn_bias_type=attn_bias_type, + 
attn_mask_type=attn_mask_type, + qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) class GeluPrimitive(BasePrimitive): diff --git a/transformer_engine/jax/csrc/extensions.cpp b/transformer_engine/jax/csrc/extensions.cpp index 5faec6fd10..5e4ab4f205 100644 --- a/transformer_engine/jax/csrc/extensions.cpp +++ b/transformer_engine/jax/csrc/extensions.cpp @@ -49,10 +49,6 @@ pybind11::dict Registrations() { EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForward); dict["te_scaled_upper_triang_masked_softmax_backward"] = EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward); - dict["te_self_fused_attn_forward"] = EncapsulateFunction(SelfFusedAttnForward); - dict["te_self_fused_attn_backward"] = EncapsulateFunction(SelfFusedAttnBackward); - dict["te_cross_fused_attn_forward"] = EncapsulateFunction(CrossFusedAttnForward); - dict["te_cross_fused_attn_backward"] = EncapsulateFunction(CrossFusedAttnBackward); dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward); dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward); return dict; @@ -72,10 +68,6 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_dgelu_dbias_ct_workspace_sizes", &GetDGeluDBiasCastTransposeWorkspaceSizes); m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes); m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes); - m.def("get_self_fused_attn_fwd_workspace_sizes", &GetSelfFusedAttnForwardWorkspaceSizes); - m.def("get_self_fused_attn_bwd_workspace_sizes", &GetSelfFusedAttnBackwardWorkspaceSizes); - m.def("get_cross_fused_attn_fwd_workspace_sizes", &GetCrossFusedAttnForwardWorkspaceSizes); - m.def("get_cross_fused_attn_bwd_workspace_sizes", &GetCrossFusedAttnBackwardWorkspaceSizes); m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); diff --git a/transformer_engine/jax/csrc/modules.cpp b/transformer_engine/jax/csrc/modules.cpp index ceacc7c90f..1c4c468d51 100644 --- a/transformer_engine/jax/csrc/modules.cpp +++ b/transformer_engine/jax/csrc/modules.cpp @@ -11,14 +11,12 @@ #include #include -#include -#include #include #include #include #include "common/common.h" -#include "common/util/cuda_runtime.h" +#include "common/util/logging.h" #include "transformer_engine/activation.h" #include "transformer_engine/cast.h" #include "transformer_engine/fused_attn.h" @@ -96,13 +94,13 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t paddin pybind11::bytes PackCustomCallFusedAttnDescriptor( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, - size_t wkspace_size, float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - DType dtype, DType wkspace_dtype, bool is_training) { + size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, + bool is_training) { return PackOpaque(CustomCallFusedAttnDescriptor{ input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, - bias_heads, head_dim, wkspace_size, scaling_factor, dropout_probability, - bias_type, mask_type, dtype, wkspace_dtype, is_training}); + bias_heads, head_dim, 
wkspace_size, scaling_factor, dropout_probability, bias_type, + mask_type, qkv_layout, dtype, wkspace_dtype, is_training}); } void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream, @@ -942,12 +940,12 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, float dropout_probability, - size_t q_num_heads, size_t kv_num_heads, + size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim) { auto backend = nvte_get_fused_attn_backend( static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, - mask_type, dropout_probability, q_num_heads, kv_num_heads, q_max_seqlen, kv_max_seqlen, + mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, head_dim); return backend; } @@ -1029,244 +1027,31 @@ void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, } } -pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes( - size_t input_batch, size_t bias_batch, size_t max_seqlen, - size_t attn_heads, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) { - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD; - - auto qkv_shape = std::vector{input_batch * max_seqlen, 3, attn_heads, head_dim}; - auto bias_shape = std::vector{bias_batch, bias_heads, max_seqlen, max_seqlen}; - - auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); - auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype); - auto cu_seqlens_tensor = - TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); - auto o_tensor = TensorWrapper( - nullptr, std::vector{input_batch, max_seqlen, attn_heads, head_dim}, dtype); - auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); - auto rng_state_tensor = TensorWrapper(nullptr, std::vector{2}, DType::kInt64); - - auto backend = nvte_get_fused_attn_backend( - static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, - mask_type, dropout_probability, attn_heads, attn_heads, max_seqlen, max_seqlen, head_dim); - - NVTETensorPack aux_output_tensors; - nvte_tensor_pack_create(&aux_output_tensors); - - TensorWrapper query_workspace_tensor; - nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), - o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(), - rng_state_tensor.data(), max_seqlen, is_training, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, - query_workspace_tensor.data(), nullptr); - - auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); - return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); -} - -void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len) { - const CustomCallFusedAttnDescriptor &descriptor = - *UnpackOpaque(opaque, opaque_len); - - // input buffers from XLA - void *qkv = buffers[0]; - void *bias = buffers[1]; - void *cu_seqlens = buffers[2]; - void *seed = buffers[3]; - - // output buffers from XLA - void *output = buffers[4]; - void *softmax_aux = buffers[5]; - void *rng_state = buffers[6]; - void *workspace = buffers[7]; - - // tensor sizes - auto input_batch = descriptor.input_batch; - auto bias_batch = descriptor.bias_batch; - auto max_seqlen = descriptor.q_max_seqlen; - 
auto attn_heads = descriptor.attn_heads; - auto bias_heads = descriptor.bias_heads; - auto head_dim = descriptor.head_dim; - auto dropout_probability = descriptor.dropout_probability; - auto bias_type = descriptor.bias_type; - auto mask_type = descriptor.mask_type; - - auto dtype = descriptor.dtype; - auto qkv_shape = std::vector{input_batch * max_seqlen, 3, attn_heads, head_dim}; - auto bias_shape = std::vector{bias_batch, bias_heads, max_seqlen, max_seqlen}; - - // input tensors - auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); - auto bias_tensor = TensorWrapper(bias, bias_shape, dtype); - auto cu_seqlens_tensor = - TensorWrapper(cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); - - // output tensors - auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in FP16/BF16 - auto o_tensor = TensorWrapper( - output, std::vector{input_batch * max_seqlen, attn_heads, head_dim}, dtype); - - // prep RNG state - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD; - auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); - auto backend = nvte_get_fused_attn_backend( - static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, - mask_type, dropout_probability, attn_heads, attn_heads, max_seqlen, max_seqlen, head_dim); - PopulateRngStateAsync(rng_state, seed, max_seqlen, max_seqlen, backend, stream); - - // auxiliary tensors (to be propagated to the backward pass later) - NVTETensorPack aux_output_tensors; - nvte_tensor_pack_create(&aux_output_tensors); - PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend, - softmax_aux); - - // cuDNN workspace - auto wkspace_size = std::vector{descriptor.wkspace_size}; - auto wkspace_dtype = descriptor.wkspace_dtype; - auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype); - - nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), - o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(), - rng_state_tensor.data(), max_seqlen, descriptor.is_training, - descriptor.scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, workspace_tensor.data(), stream); - - nvte_tensor_pack_destroy(&aux_output_tensors); -} - -pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes( - size_t input_batch, size_t bias_batch, size_t max_seqlen, - size_t attn_heads, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) { - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD; - - auto qkv_shape = std::vector{input_batch * max_seqlen, 3, attn_heads, head_dim}; - auto output_shape = std::vector{input_batch * max_seqlen, attn_heads, head_dim}; - auto bias_shape = std::vector{bias_batch, bias_heads, max_seqlen, max_seqlen}; - - auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); - auto output_tensor = TensorWrapper(nullptr, output_shape, dtype); - auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype); - // F16 doesn't use this tensor - auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); - - auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); - auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype); - - auto cu_seqlens_tensor = - TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); - - NVTETensorPack aux_input_tensors; - nvte_tensor_pack_create(&aux_input_tensors); - - TensorWrapper query_workspace_tensor; - 
nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), - cu_seqlens_tensor.data(), max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, - query_workspace_tensor.data(), nullptr); - - auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); - return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); -} - -void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len) { - const CustomCallFusedAttnDescriptor &descriptor = - *UnpackOpaque(opaque, opaque_len); - - // input buffers from XLA - void *qkv = buffers[0]; - void *bias = buffers[1]; - void *softmax_aux = buffers[2]; - void *rng_state = buffers[3]; - void *output = buffers[4]; - void *doutput = buffers[5]; - void *cu_seqlens = buffers[6]; - - // output buffers from XLA - void *dqkv = buffers[7]; - void *dbias = buffers[8]; - void *workspace = buffers[9]; - - // tensor sizes - auto input_batch = descriptor.input_batch; - auto bias_batch = descriptor.bias_batch; - auto max_seqlen = descriptor.q_max_seqlen; - auto attn_heads = descriptor.attn_heads; - auto bias_heads = descriptor.bias_heads; - auto head_dim = descriptor.head_dim; - auto scaling_factor = descriptor.scaling_factor; - auto dropout_probability = descriptor.dropout_probability; - auto bias_type = descriptor.bias_type; - auto mask_type = descriptor.mask_type; - - auto dtype = descriptor.dtype; - auto qkv_shape = std::vector{input_batch * max_seqlen, 3, attn_heads, head_dim}; - auto output_shape = std::vector{input_batch * max_seqlen, attn_heads, head_dim}; - auto bias_shape = std::vector{bias_batch, bias_heads, max_seqlen, max_seqlen}; - - // input tensors - auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); - auto output_tensor = TensorWrapper(output, output_shape, dtype); - auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype); - - // output tensors - auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in FP16/BF16 - auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype); - auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype); - auto cu_seqlens_tensor = - TensorWrapper(cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); - - // auxiliary tensors (propagated from the forward pass) - NVTETensorPack aux_input_tensors; - nvte_tensor_pack_create(&aux_input_tensors); - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD; - auto backend = nvte_get_fused_attn_backend( - static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, - mask_type, dropout_probability, attn_heads, attn_heads, max_seqlen, max_seqlen, head_dim); - PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux, - rng_state, bias); - - // cuDNN workspace - auto wkspace_size = std::vector{descriptor.wkspace_size}; - auto wkspace_dtype = descriptor.wkspace_dtype; - auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype); - - nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), - cu_seqlens_tensor.data(), max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, - workspace_tensor.data(), stream); - - 
nvte_tensor_pack_destroy(&aux_input_tensors); -} - -pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes( +pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) { - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD; + float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training) { + // For qkv_packed + auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; + auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); + // For kv_packed auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; - auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; - auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype); + // For separate q, k, v + auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; + auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); + auto v_shape = k_shape; + auto v_tensor = TensorWrapper(nullptr, v_shape, dtype); + + auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype); - // FP16/BF16 doesn't use this tensor + // F16 doesn't use this tensor auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); auto o_tensor = TensorWrapper(nullptr, q_shape, dtype); @@ -1281,292 +1066,133 @@ pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes( nvte_tensor_pack_create(&aux_output_tensors); TensorWrapper query_workspace_tensor; - nvte_fused_attn_fwd_kvpacked(q_tensor.data(), kv_tensor.data(), bias_tensor.data(), - s_tensor.data(), o_tensor.data(), &aux_output_tensors, - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, - is_training, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, query_workspace_tensor.data(), nullptr); - - auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); - return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); -} - -void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len) { - const CustomCallFusedAttnDescriptor &descriptor = - *UnpackOpaque(opaque, opaque_len); - - // input buffers from XLA - void *q = buffers[0]; - void *kv = buffers[1]; - void *bias = buffers[2]; - void *q_cu_seqlens = buffers[3]; - void *kv_cu_seqlens = buffers[4]; - void *seed = buffers[5]; - - // output buffers from XLA - void *output = buffers[6]; - void *softmax_aux = buffers[7]; - void *rng_state = buffers[8]; - void *workspace = buffers[9]; - - // tensor sizes - auto input_batch = descriptor.input_batch; - auto bias_batch = descriptor.bias_batch; - auto q_max_seqlen = descriptor.q_max_seqlen; - auto kv_max_seqlen = descriptor.kv_max_seqlen; - auto attn_heads = descriptor.attn_heads; - auto num_gqa_groups = descriptor.num_gqa_groups; - auto bias_heads = descriptor.bias_heads; - auto head_dim = 
descriptor.head_dim; - auto scaling_factor = descriptor.scaling_factor; - auto dropout_probability = descriptor.dropout_probability; - auto bias_type = descriptor.bias_type; - auto mask_type = descriptor.mask_type; - - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; - auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; - - // input tensors - auto dtype = descriptor.dtype; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto kv_tensor = TensorWrapper(kv, kv_shape, dtype); - auto bias_tensor = TensorWrapper(bias, bias_shape, dtype); - - // output tensors - auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in FP16/BF16 - auto o_tensor = TensorWrapper(output, q_shape, dtype); - auto q_cu_seqlens_tensor = - TensorWrapper(q_cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); - auto kv_cu_seqlens_tensor = - TensorWrapper(kv_cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); - - // prep RNG state - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD; - auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); - auto backend = nvte_get_fused_attn_backend( - static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, - mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, - head_dim); - PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); - - // auxiliary tensors (to be propagated to the backward pass later) - NVTETensorPack aux_output_tensors; - nvte_tensor_pack_create(&aux_output_tensors); - PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend, - softmax_aux); - - // cuDNN workspace - auto workspace_tensor = TensorWrapper(workspace, std::vector{descriptor.wkspace_size}, - descriptor.wkspace_dtype); - - nvte_fused_attn_fwd_kvpacked(q_tensor.data(), kv_tensor.data(), bias_tensor.data(), - s_tensor.data(), o_tensor.data(), &aux_output_tensors, - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, - descriptor.is_training, scaling_factor, dropout_probability, - qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream); + if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) { + assert(q_max_seqlen == kv_max_seqlen); + nvte_fused_attn_fwd_qkvpacked( + qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), + &aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_rng_state_tensor.data(), + q_max_seqlen, is_training, scaling_factor, dropout_probability, qkv_layout, bias_type, + mask_type, query_workspace_tensor.data(), nullptr); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { + nvte_fused_attn_fwd_kvpacked(q_tensor.data(), kv_tensor.data(), bias_tensor.data(), + s_tensor.data(), o_tensor.data(), &aux_output_tensors, + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, + is_training, scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, query_workspace_tensor.data(), nullptr); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { + nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), + s_tensor.data(), o_tensor.data(), &aux_output_tensors, + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + dummy_rng_state_tensor.data(), q_max_seqlen, 
kv_max_seqlen, is_training, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + query_workspace_tensor.data(), nullptr); + } else { + NVTE_ERROR("Unsupported QKVLayout."); + } - nvte_tensor_pack_destroy(&aux_output_tensors); + auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape()); + return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype()); } -pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes( - size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) { - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD; - - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; - auto output_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; - - auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); - auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype); - auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype); +pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( + size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, + size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, + bool is_training) { + auto output_shape = std::vector{batch_size * q_max_seqlen, attn_heads, head_dim}; auto output_tensor = TensorWrapper(nullptr, output_shape, dtype); - // FP16/BF16 doesn't use this tensor - auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); + auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype); - auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); - auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype); + auto bias_shape = std::vector{1, attn_heads, q_max_seqlen, kv_max_seqlen}; auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype); - auto q_cu_seqlens_tensor = - TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); - auto kv_cu_seqlens_tensor = - TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); - - NVTETensorPack aux_input_tensors; - nvte_tensor_pack_create(&aux_input_tensors); - - TensorWrapper query_workspace_tensor; - nvte_fused_attn_bwd_kvpacked( - q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for FP16/BF16 - s_tensor.data(), // not used for FP16/BF16 - &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - query_workspace_tensor.data(), nullptr); - - auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); - return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); -} - -void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len) { - const CustomCallFusedAttnDescriptor &descriptor = - *UnpackOpaque(opaque, opaque_len); - - // input buffers from XLA - void *q = buffers[0]; - void *kv = buffers[1]; - void *bias = 
buffers[2]; - void *softmax_aux = buffers[3]; - void *rng_state = buffers[4]; - void *output = buffers[5]; - void *doutput = buffers[6]; - void *q_cu_seqlens = buffers[7]; - void *kv_cu_seqlens = buffers[8]; - - // output buffers from XLA - void *dq = buffers[9]; - void *dkv = buffers[10]; - void *dbias = buffers[11]; - void *workspace = buffers[12]; - - // tensor sizes - auto input_batch = descriptor.input_batch; - auto bias_batch = descriptor.bias_batch; - auto q_max_seqlen = descriptor.q_max_seqlen; - auto kv_max_seqlen = descriptor.kv_max_seqlen; - auto attn_heads = descriptor.attn_heads; - auto num_gqa_groups = descriptor.num_gqa_groups; - auto bias_heads = descriptor.bias_heads; - auto head_dim = descriptor.head_dim; - auto scaling_factor = descriptor.scaling_factor; - auto dropout_probability = descriptor.dropout_probability; - auto bias_type = descriptor.bias_type; - auto mask_type = descriptor.mask_type; - - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; - auto output_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; - - // input tensors - auto dtype = descriptor.dtype; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto kv_tensor = TensorWrapper(kv, kv_shape, dtype); - auto output_tensor = TensorWrapper(output, output_shape, dtype); - auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype); + // F16 doesn't use s_tensor + auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); - // output tensors - auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in FP16/BF16 - auto dq_tensor = TensorWrapper(dq, q_shape, dtype); - auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype); - auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype); auto q_cu_seqlens_tensor = - TensorWrapper(q_cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); + TensorWrapper(nullptr, std::vector{batch_size + 1}, DType::kInt32); auto kv_cu_seqlens_tensor = - TensorWrapper(kv_cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); + TensorWrapper(nullptr, std::vector{batch_size + 1}, DType::kInt32); - // auxiliary tensors (propagated from the forward pass) NVTETensorPack aux_input_tensors; nvte_tensor_pack_create(&aux_input_tensors); - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD; - auto backend = nvte_get_fused_attn_backend( - static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, - mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, - head_dim); - PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux, - rng_state, bias); - - // cuDNN workspace - auto wkspace_size = std::vector{descriptor.wkspace_size}; - auto wkspace_dtype = descriptor.wkspace_dtype; - auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype); - - nvte_fused_attn_bwd_kvpacked( - q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for FP16/BF16 - s_tensor.data(), // not used for FP16/BF16 - &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - workspace_tensor.data(), stream); - - 
nvte_tensor_pack_destroy(&aux_input_tensors); -} - -pybind11::tuple GetFusedAttnForwardWorkspaceSizes( - size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) { - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD; - - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; - auto v_shape = k_shape; - auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; - - auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); - auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); - auto v_tensor = TensorWrapper(nullptr, v_shape, dtype); - - auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype); - - // F16 doesn't use this tensor - auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); - auto o_tensor = TensorWrapper(nullptr, q_shape, dtype); - - auto q_cu_seqlens_tensor = - TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); - auto kv_cu_seqlens_tensor = - TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); - - auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector{2}, DType::kInt64); - - NVTETensorPack aux_output_tensors; - nvte_tensor_pack_create(&aux_output_tensors); TensorWrapper query_workspace_tensor; - nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), - s_tensor.data(), o_tensor.data(), &aux_output_tensors, - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - query_workspace_tensor.data(), nullptr); - auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); - return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); + if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) { + assert(q_max_seqlen == kv_max_seqlen); + auto qkv_shape = std::vector{batch_size * q_max_seqlen, 3, attn_heads, head_dim}; + auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); + auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); + nvte_fused_attn_bwd_qkvpacked( + qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), + q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + query_workspace_tensor.data(), nullptr); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { + auto q_shape = std::vector{batch_size * q_max_seqlen, attn_heads, head_dim}; + auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto kv_shape = + std::vector{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim}; + auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype); + auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype); + nvte_fused_attn_bwd_kvpacked( + q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), 
+ q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + query_workspace_tensor.data(), nullptr); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { + auto q_shape = std::vector{batch_size * q_max_seqlen, attn_heads, head_dim}; + auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto k_shape = std::vector{batch_size * kv_max_seqlen, num_gqa_groups, head_dim}; + auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); + auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype); + auto v_shape = k_shape; + auto v_tensor = TensorWrapper(nullptr, v_shape, dtype); + auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype); + nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), + dv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + query_workspace_tensor.data(), nullptr); + } else { + NVTE_ERROR("Unsupported QKVLayout."); + } + + auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape()); + return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype()); } void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { const CustomCallFusedAttnDescriptor &descriptor = *UnpackOpaque(opaque, opaque_len); - // input buffers from XLA - void *q = buffers[0]; - void *k = buffers[1]; - void *v = buffers[2]; + /* Input buffers from XLA */ + /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */ void *bias = buffers[3]; void *q_cu_seqlens = buffers[4]; void *kv_cu_seqlens = buffers[5]; void *seed = buffers[6]; - // output buffers from XLA + /* Output buffer from XLA */ void *output = buffers[7]; void *softmax_aux = buffers[8]; void *rng_state = buffers[9]; void *workspace = buffers[10]; - // tensor sizes + /* Descriptor */ auto input_batch = descriptor.input_batch; auto bias_batch = descriptor.bias_batch; auto q_max_seqlen = descriptor.q_max_seqlen; @@ -1579,29 +1205,26 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s auto dropout_probability = descriptor.dropout_probability; auto bias_type = descriptor.bias_type; auto mask_type = descriptor.mask_type; + auto qkv_layout = descriptor.qkv_layout; + auto dtype = descriptor.dtype; + /* Input tensors */ auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; auto v_shape = k_shape; auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; - - // input tensors - auto dtype = descriptor.dtype; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto k_tensor = TensorWrapper(k, k_shape, dtype); - auto v_tensor = TensorWrapper(v, v_shape, dtype); auto bias_tensor = TensorWrapper(bias, bias_shape, dtype); - // output tensors + /* Output tensors */ auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in F16 - auto o_tensor = TensorWrapper(output, q_shape, dtype); + auto o_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto o_tensor = 
TensorWrapper(output, o_shape, dtype); auto q_cu_seqlens_tensor = TensorWrapper(q_cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); auto kv_cu_seqlens_tensor = TensorWrapper(kv_cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); - // prep RNG state - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD; + /* Prepare RNG state */ auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); auto backend = nvte_get_fused_attn_backend( static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, @@ -1609,22 +1232,59 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s head_dim); PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); - // auxiliary tensors (to be propagated to the backward pass later) + /* Auxiliary tensors (to be propagated to the backward pass later) */ NVTETensorPack aux_output_tensors; nvte_tensor_pack_create(&aux_output_tensors); PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend, softmax_aux); - // cuDNN workspace + /* cuDNN workspace */ auto workspace_tensor = TensorWrapper(workspace, std::vector{descriptor.wkspace_size}, descriptor.wkspace_dtype); - nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), - s_tensor.data(), o_tensor.data(), &aux_output_tensors, - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, - descriptor.is_training, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, workspace_tensor.data(), stream); + /* Call the underly NVTE API */ + if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) { + auto qkv = buffers[0]; + auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; + auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); + nvte_fused_attn_fwd_qkvpacked( + qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), + &aux_output_tensors, q_cu_seqlens_tensor.data(), rng_state_tensor.data(), q_max_seqlen, + descriptor.is_training, descriptor.scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, workspace_tensor.data(), stream); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { + auto q = buffers[0]; + auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto kv = buffers[1]; + auto kv_shape = + std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; + auto kv_tensor = TensorWrapper(kv, kv_shape, dtype); + nvte_fused_attn_fwd_kvpacked( + q_tensor.data(), kv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), + &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + workspace_tensor.data(), stream); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { + auto q = buffers[0]; + auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto k = buffers[1]; + auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; + auto k_tensor = TensorWrapper(k, k_shape, dtype); + auto v = buffers[2]; + auto v_shape = k_shape; + auto v_tensor = TensorWrapper(v, v_shape, dtype); + nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), 
v_tensor.data(), bias_tensor.data(), + s_tensor.data(), o_tensor.data(), &aux_output_tensors, + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, + descriptor.is_training, scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, workspace_tensor.data(), stream); + } else { + NVTE_ERROR("Unsupported qkv_layout."); + } nvte_tensor_pack_destroy(&aux_output_tensors); } @@ -1632,10 +1292,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) { - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD; - + float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; auto v_shape = k_shape; @@ -1682,10 +1340,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, const CustomCallFusedAttnDescriptor &descriptor = *UnpackOpaque(opaque, opaque_len); - // input buffers from XLA - void *q = buffers[0]; - void *k = buffers[1]; - void *v = buffers[2]; + /* Input buffers from XLA */ + /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */ void *bias = buffers[3]; void *softmax_aux = buffers[4]; void *rng_state = buffers[5]; @@ -1694,14 +1350,12 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, void *q_cu_seqlens = buffers[8]; void *kv_cu_seqlens = buffers[9]; - // output buffers from XLA - void *dq = buffers[10]; - void *dk = buffers[11]; - void *dv = buffers[12]; + /* Output buffer from XLA */ + /* Buffers[10-12] are dq, dk, dv, which are parsed later for different qkv_layout */ void *dbias = buffers[13]; void *workspace = buffers[14]; - // tensor sizes + /* Descriptor */ auto input_batch = descriptor.input_batch; auto bias_batch = descriptor.bias_batch; auto q_max_seqlen = descriptor.q_max_seqlen; @@ -1714,36 +1368,26 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, auto dropout_probability = descriptor.dropout_probability; auto bias_type = descriptor.bias_type; auto mask_type = descriptor.mask_type; + auto qkv_layout = descriptor.qkv_layout; + auto dtype = descriptor.dtype; - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; - auto v_shape = k_shape; + /* Input tensors */ auto output_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; - - // input tensors - auto dtype = descriptor.dtype; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto k_tensor = TensorWrapper(k, k_shape, dtype); - auto v_tensor = TensorWrapper(v, v_shape, dtype); auto output_tensor = TensorWrapper(output, output_shape, dtype); auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype); - // output tensors + /* Output tensors */ auto s_tensor = TensorWrapper(nullptr, 
std::vector{1}, dtype); // not used in F16 - auto dq_tensor = TensorWrapper(dq, q_shape, dtype); - auto dk_tensor = TensorWrapper(dk, k_shape, dtype); - auto dv_tensor = TensorWrapper(dv, v_shape, dtype); auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype); auto q_cu_seqlens_tensor = TensorWrapper(q_cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); auto kv_cu_seqlens_tensor = TensorWrapper(kv_cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); - // auxiliary tensors (propagated from the forward pass) + /* Auxiliary tensors (propagated from the forward pass) */ NVTETensorPack aux_input_tensors; nvte_tensor_pack_create(&aux_input_tensors); - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD; auto backend = nvte_get_fused_attn_backend( static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, @@ -1751,20 +1395,73 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux, rng_state, bias); - // cuDNN workspace + /* cuDNN workspace */ auto wkspace_size = std::vector{descriptor.wkspace_size}; auto wkspace_dtype = descriptor.wkspace_dtype; auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype); - nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), - doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, - workspace_tensor.data(), stream); + /* Call the underly NVTE API */ + if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) { + auto qkv = buffers[0]; + auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; + auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); + auto dqkv = buffers[10]; + auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype); + nvte_fused_attn_bwd_qkvpacked( + qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), + q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + workspace_tensor.data(), stream); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { + auto q = buffers[0]; + auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto kv = buffers[1]; + auto kv_shape = + std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; + auto kv_tensor = TensorWrapper(kv, kv_shape, dtype); + auto dq = buffers[10]; + auto dq_tensor = TensorWrapper(dq, q_shape, dtype); + auto dkv = buffers[11]; + auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype); + nvte_fused_attn_bwd_kvpacked( + q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, + scaling_factor, dropout_probability, 
qkv_layout, bias_type, mask_type, + workspace_tensor.data(), stream); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { + auto q = buffers[0]; + auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto k = buffers[1]; + auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; + auto k_tensor = TensorWrapper(k, k_shape, dtype); + auto v = buffers[2]; + auto v_shape = k_shape; + auto v_tensor = TensorWrapper(v, v_shape, dtype); + auto dq = buffers[10]; + auto dq_tensor = TensorWrapper(dq, q_shape, dtype); + auto dk = buffers[11]; + auto dk_tensor = TensorWrapper(dk, k_shape, dtype); + auto dv = buffers[12]; + auto dv_tensor = TensorWrapper(dv, v_shape, dtype); + nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), + dv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + workspace_tensor.data(), stream); + } else { + NVTE_ERROR("Unsupported qkv_layout."); + } nvte_tensor_pack_destroy(&aux_input_tensors); } diff --git a/transformer_engine/jax/csrc/modules.h b/transformer_engine/jax/csrc/modules.h index a2b873235e..e392931d04 100644 --- a/transformer_engine/jax/csrc/modules.h +++ b/transformer_engine/jax/csrc/modules.h @@ -118,17 +118,18 @@ struct CustomCallFusedAttnDescriptor { float dropout_probability; NVTE_Bias_Type bias_type; NVTE_Mask_Type mask_type; + NVTE_QKV_Layout qkv_layout; DType dtype; DType wkspace_dtype; bool is_training; }; pybind11::bytes PackCustomCallFusedAttnDescriptor( - size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, + size_t input_batch, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, - size_t wkspace_size, float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - DType dtype, DType wkspace_dtype, bool is_training); + size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, + bool is_training); NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, @@ -207,55 +208,19 @@ void ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len); -pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes( - size_t input_batch, size_t bias_batch, size_t max_seqlen, - size_t attn_heads, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training); - -void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len); - -pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes( - size_t input_batch, size_t bias_batch, size_t max_seqlen, - size_t attn_heads, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, 
NVTE_Mask_Type mask_type, DType dtype, bool is_training); - -void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len); - -pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes( - size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training); - -void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len); - -pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes( - size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training); - -void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len); - pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training); + float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training); void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training); + float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training); void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 6979cffd90..fcf06aa128 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -26,7 +26,7 @@ from .module import LayerNorm, Softmax from ..fused_attn import AttnBiasType, AttnMaskType, QKVLayout from ..fused_attn import is_fused_attn_kernel_available, canonicalize_attn_mask_type -from ..fused_attn import self_fused_attn, cross_fused_attn, fused_attn +from ..fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked, fused_attn from ..softmax import SoftmaxType from ..sharding import num_of_devices from ..sharding import get_sharding_map_logic_axis_to_mesh_axis @@ -190,16 +190,19 @@ def __call__(self, def convert_to_softmax_type(attn_mask_type, mask): """Convert the attn_mask_type to SoftmaxType""" + # mask is ignored for no_mask and causal_mask + if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: + mask = None if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]: - return SoftmaxType.SCALED_UPPER_TRIANG_MASKED + return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask if attn_mask_type in 
[AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]: if mask is not None: - return SoftmaxType.SCALED_MASKED - return SoftmaxType.SCALED - raise ValueError(f"Unsupported {attn_mask_type=}, " - "supported attn_mask_type = {'causal', 'padding'}") + return SoftmaxType.SCALED_MASKED, mask + return SoftmaxType.SCALED, mask + raise ValueError(f"Unsupported {attn_mask_type=}, supported attn_mask_type=" + "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}") - softmax_type = convert_to_softmax_type(self.attn_mask_type, mask) + softmax_type, mask = convert_to_softmax_type(self.attn_mask_type, mask) attn_weights = Softmax(softmax_type=softmax_type, scale_factor=fused_scale_factor)(attn_weights, mask, @@ -266,15 +269,15 @@ def __call__(self, qkv_packed = query if self.transpose_batch_sequence: qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4]) - x = self_fused_attn(qkv_packed, - bias, - mask, - seed, - attn_mask_type=self.attn_mask_type, - attn_bias_type=self.attn_bias_type, - scaling_factor=scale_factor, - dropout_probability=self.attention_dropout, - is_training=not deterministic) + x = fused_attn_qkvpacked(qkv_packed, + bias, + mask, + seed, + attn_mask_type=self.attn_mask_type, + attn_bias_type=self.attn_bias_type, + scaling_factor=scale_factor, + dropout_probability=self.attention_dropout, + is_training=not deterministic) elif self.qkv_layout == QKVLayout.BSHD_BS2HD: """kvpacked format, treat query: query tensor, shape = [..., h, d] @@ -285,16 +288,16 @@ def __call__(self, if self.transpose_batch_sequence: query = query.transpose([1, 0, 2, 3]) kv_packed = kv_packed.transpose([1, 0, 2, 3, 4]) - x = cross_fused_attn(query, - kv_packed, - bias, - mask, - seed, - attn_mask_type=self.attn_mask_type, - attn_bias_type=self.attn_bias_type, - scaling_factor=scale_factor, - dropout_probability=self.attention_dropout, - is_training=not deterministic) + x = fused_attn_kvpacked(query, + kv_packed, + bias, + mask, + seed, + attn_mask_type=self.attn_mask_type, + attn_bias_type=self.attn_bias_type, + scaling_factor=scale_factor, + dropout_probability=self.attention_dropout, + is_training=not deterministic) elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD: if self.transpose_batch_sequence: query = query.transpose([1, 0, 2, 3]) @@ -358,11 +361,27 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-method attention_dropout: float, default = 0.0 Dropout probability for the dropout op after the softmax. attn_mask_type: str, default = 'causal' - Type of the attention mask passed into softmax operation in the self attention. - Available options: {'no_mask', 'padding', 'causal', 'causal_padding'} - Introduced in v0.10.0. + This parameter specifies the type of attention mask to be applied during the softmax + operation. + Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'} + + Each described below: + + * no_mask: No attention mask is applied. This means the attention will consider the + full sequence without any restrictions. + * padding: Indicates the presence of padding at the end of each sequence. + Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the + :attr:`__call__` method to specify the padding positions. + * causal: An upper triangular mask is applied to the softmax inputs, + ensuring that the prediction for a certain position is only dependent on known outputs + from positions before it. + * causal_padding / padding_causal: A combination of both causal and padding masks. 
+ Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect. + + .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'. + attn_bias_type: Optional[str], default = None - Type of the attention bias passed in the self attention. + Type of the attention bias passed in the attention. Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}. When default is present, the type is automatically decided by the MHA's bias parameter. Where it is :attr:`post_scale_bias` if there is bias. Otherwise :attr:`no_bias` is used. @@ -438,6 +457,7 @@ def __call__(self, mask: jax.numpy.ndarray, default = None Boolean tensor used to mask out the attention softmax input. :attr:`True` means to mask out the corresponding values. + Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'. bias: jax.numpy.ndarray, default = None A tensor used to shift attention softmax input. *: @@ -639,9 +659,25 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods attention_dropout: float, default = 0.0 Dropout probability for the dropout op after the softmax. attn_mask_type: str, default = 'causal' - Type of the attention mask passed into softmax operation in the attention. - Available options: {'no_mask', 'padding', 'causal', 'causal_padding'} - Introduced in v0.10.0. + This parameter specifies the type of attention mask to be applied during the softmax + operation. + Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'} + + Each described below: + + * no_mask: No attention mask is applied. This means the attention will consider the + full sequence without any restrictions. + * padding: Indicates the presence of padding at the end of each sequence. + Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the + :attr:`__call__` method to specify the padding positions. + * causal: An upper triangular mask is applied to the softmax inputs, + ensuring that the prediction for a certain position is only dependent on known outputs + from positions before it. + * causal_padding / padding_causal: A combination of both causal and padding masks. + Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect. + + .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'. + attn_bias_type: Optional[str], default = None Type of the attention bias passed in the attention. Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}. @@ -809,6 +845,7 @@ def __call__(self, mask: jax.numpy.ndarray, default = None Boolean tensor used to mask out the attention softmax input. :attr:`True` means mask out the corresponding values. + Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'. bias: jax.numpy.ndarray, default = None A tensor used to shift the attention softmax input. * @@ -1299,9 +1336,25 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods is added after self-attention.this can be used for structures like `T5` Transformer in conjunction with the TransformerLayerType.ENCODER option. self_attn_mask_type: str, default = 'causal' - Type of the attention mask passed into softmax operation in the self attention. - Available options: {'no_mask', 'padding', 'causal', 'causal_padding'} - Introduced in v0.10.0. + This parameter specifies the type of attention mask to be applied during the softmax + operation in the self attention. 
+ Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'} + + Each described below: + + * no_mask: No attention mask is applied. This means the self attention will consider the + full sequence without any restrictions. + * padding: Indicates the presence of padding at the end of each sequence. + Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the + :attr:`__call__` method to specify the padding positions. + * causal: An upper triangular mask is applied to the softmax inputs, + ensuring that the prediction for a certain position is only dependent on known outputs + from positions before it. + * causal_padding / padding_causal: A combination of both causal and padding masks. + Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect. + + .. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'. + self_attn_bias_type: Optional[str], default = None Type of the attention bias passed into the self attention. Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}. @@ -1420,9 +1473,12 @@ def __call__(self, :attr:`layer_type=TransformerLayerType.DECODER`. attention_mask : jax.numpy.ndarray, default = None Boolean tensor used to mask out self-attention softmax input. + :attr:`True` means mask out the corresponding values. + Ignored when :attr:`self.self_attn_mask_type` is either 'no_mask' or 'causal'. encoder_decoder_mask: jax.numpy.ndarray, default = None Boolean tensor used to mask out cross-attention softmax input when :attr:`layer_type=TransformerLayerType.DECODER`. + :attr:`True` means mask out the corresponding values. deterministic: bool, default = False Disable dropout layers if set to True. decode: bool, default = False diff --git a/transformer_engine/jax/fused_attn.py b/transformer_engine/jax/fused_attn.py index 008f13ef4a..8b32163811 100644 --- a/transformer_engine/jax/fused_attn.py +++ b/transformer_engine/jax/fused_attn.py @@ -14,20 +14,29 @@ from transformer_engine_jax import NVTE_QKV_Layout from .cpp_extensions import FusedAttnHelper -from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd -from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd +from .cpp_extensions import fused_attn_fwd_kvpacked, fused_attn_bwd_kvpacked +from .cpp_extensions import fused_attn_fwd_qkvpacked, fused_attn_bwd_qkvpacked from .cpp_extensions import fused_attn_fwd, fused_attn_bwd class AttnBiasType(Enum): - """Attention Bias Type.""" + """ + NO_BIAS: Softmax is performed as softmax(scale * qk) + PRE_SCALE_BIAS: Softmax is performed as softmax(scale * (qk + bias)) + POST_SCALE_BIAS: Softmax is performed as softmax(scale * qk + bias) + """ NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS PRE_SCALE_BIAS = NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS POST_SCALE_BIAS = NVTE_Bias_Type.NVTE_POST_SCALE_BIAS class AttnMaskType(Enum): - """Attention Mask Type.""" + """ + NO_MASK: No attention mask is applied. + PADDING_MASK: Indicates the presence of paddings at the end of each sequence. + CAUSAL_MASK: An upper triangular mask is applied to the softmax inputs. + PADDING_CAUSAL_MASK: A combination of both causal and padding masks. + """ NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK @@ -47,99 +56,105 @@ def canonicalize_attn_mask_type(attn_mask_type: str): The overhead between padding and non-padding version should be small. However, we will lease this limitation in the near feature. 
""" - if attn_mask_type in ['causal', 'padding_causal']: - return AttnMaskType.PADDING_CAUSAL_MASK - if attn_mask_type in ['no_mask', 'padding']: - return AttnMaskType.PADDING_MASK - raise ValueError(f"Unsupported {attn_mask_type=}, " - "supported attn_mask_type={'no_mask', 'padding', 'causal', 'padding_causal'}") - - -def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type, - dropout_probability, num_heads_q, num_heads_kv, max_seqlen_q, - max_seqlen_kv, head_dim): + match attn_mask_type: + case 'no_mask': + return AttnMaskType.NO_MASK + case 'padding': + return AttnMaskType.PADDING_MASK + case 'causal': + return AttnMaskType.CAUSAL_MASK + case 'padding_causal' | 'causal_padding': + return AttnMaskType.PADDING_CAUSAL_MASK + raise ValueError(f"Unsupported {attn_mask_type=}, supported attn_mask_type=" + "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}") + + +def is_fused_attn_kernel_available(q_dtype, kv_dtype, qkv_layout, attn_bias_type, attn_mask_type, + dropout_probability, q_num_heads, kv_num_heads, q_max_seqlen, + kv_max_seqlen, head_dim): """ - To check whether the fused attention kernel is available + To check whether the fused attention kernel is supported """ - return FusedAttnHelper(q_type, kv_type, qkv_layout.value, attn_bias_type.value, - attn_mask_type.value, dropout_probability, num_heads_q, num_heads_kv, - max_seqlen_q, max_seqlen_kv, head_dim).is_fused_attn_kernel_available() + return FusedAttnHelper(q_dtype, kv_dtype, qkv_layout.value, attn_bias_type.value, + attn_mask_type.value, dropout_probability, q_num_heads, kv_num_heads, + q_max_seqlen, kv_max_seqlen, head_dim).is_fused_attn_kernel_available() -def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray, - seed: jnp.ndarray | None, attn_bias_type: AttnBiasType, - attn_mask_type: AttnMaskType, scaling_factor: float, - dropout_probability: float, is_training: bool): +def fused_attn_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray, + seed: jnp.ndarray | None, attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, scaling_factor: float, + dropout_probability: float, is_training: bool): """ - Self fused attention wrapper + Fused attention with the qkvpacked inputs """ - output = _self_fused_attn(qkv, - bias, - mask, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + output = _fused_attn_qkvpacked(qkv, + bias, + mask, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) return output @partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8)) -def _self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray, - seed: jnp.ndarray | None, attn_bias_type: AttnBiasType, - attn_mask_type: AttnMaskType, scaling_factor: float, - dropout_probability: float, is_training: bool): - - output, _ = _self_fused_attn_fwd_rule(qkv, bias, mask, seed, attn_bias_type, attn_mask_type, - scaling_factor, dropout_probability, is_training) +def _fused_attn_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray, + seed: jnp.ndarray | None, attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, scaling_factor: float, + dropout_probability: float, is_training: bool): + + output, _ = _fused_attn_fwd_qkvpacked_rule(qkv, bias, mask, seed, attn_bias_type, + 
attn_mask_type, scaling_factor, dropout_probability, + is_training) return output -def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray | None, - mask: jnp.ndarray, seed: jnp.ndarray | None, - attn_bias_type: AttnBiasType, - attn_mask_type: AttnMaskType, - scaling_factor: float, dropout_probability: float, - is_training: bool): - if mask is None: +def _fused_attn_fwd_qkvpacked_rule(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray, + seed: jnp.ndarray | None, attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, scaling_factor: float, + dropout_probability: float, is_training: bool): + if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: batch, seqlen, *_ = qkv.shape actual_seqlen = jnp.full((batch,), seqlen, dtype=jnp.int32) else: + assert mask is not None mask = jnp.logical_not(mask) actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,) - output, softmax_aux, rng_state = self_fused_attn_fwd(qkv, - bias, - actual_seqlen, - seed, - attn_bias_type=attn_bias_type.value, - attn_mask_type=attn_mask_type.value, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + output, softmax_aux, rng_state = fused_attn_fwd_qkvpacked( + qkv, + bias, + actual_seqlen, + seed, + attn_bias_type=attn_bias_type.value, + attn_mask_type=attn_mask_type.value, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) output = checkpoint_name(output, 'context') softmax_aux = checkpoint_name(softmax_aux, 'context') rng_state = checkpoint_name(rng_state, 'context') return output, (qkv, bias, softmax_aux, rng_state, output, actual_seqlen) -def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, - is_training, ctx, dz): +def _fused_attn_bwd_qkvpacked_rule(attn_bias_type, attn_mask_type, scaling_factor, + dropout_probability, is_training, ctx, dz): qkv, bias, softmax_aux, rng_state, output, actual_seqlen = ctx - grad_qkv, grad_bias = self_fused_attn_bwd(qkv, - bias, - softmax_aux, - rng_state, - output, - dz, - actual_seqlen, - attn_bias_type=attn_bias_type.value, - attn_mask_type=attn_mask_type.value, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + grad_qkv, grad_bias = fused_attn_bwd_qkvpacked(qkv, + bias, + softmax_aux, + rng_state, + output, + dz, + actual_seqlen, + attn_bias_type=attn_bias_type.value, + attn_mask_type=attn_mask_type.value, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) if attn_bias_type == AttnBiasType.NO_BIAS: grad_bias = None @@ -147,91 +162,96 @@ def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dr return grad_qkv, grad_bias, None, None -_self_fused_attn.defvjp(_self_fused_attn_fwd_rule, _self_fused_attn_bwd_rule) +_fused_attn_qkvpacked.defvjp(_fused_attn_fwd_qkvpacked_rule, _fused_attn_bwd_qkvpacked_rule) -def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, - seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType, - scaling_factor: float, dropout_probability: float, is_training: bool): +def fused_attn_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, + seed: jnp.ndarray, attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, scaling_factor: float, + dropout_probability: float, is_training: bool): """ - Cross multi-head attention wrapper + Fused 
attention with the kvpacked inputs """ - output = _cross_fused_attn(q, - kv, - bias, - mask, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + output = _fused_attn_kvpacked(q, + kv, + bias, + mask, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) return output @partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9)) -def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, - seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType, - scaling_factor: float, dropout_probability: float, is_training: bool): - - output, _ = _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type, - scaling_factor, dropout_probability, is_training) +def _fused_attn_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, + seed: jnp.ndarray, attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, scaling_factor: float, + dropout_probability: float, is_training: bool): + + output, _ = _fused_attn_fwd_kvpacked_rule(q, kv, bias, mask, seed, attn_bias_type, + attn_mask_type, scaling_factor, dropout_probability, + is_training) return output -def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type, - scaling_factor, dropout_probability, is_training): - if mask is None: +def _fused_attn_fwd_kvpacked_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type, + scaling_factor, dropout_probability, is_training): + if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: batch, s_q, *_ = q.shape s_kv = kv.shape[1] q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32) kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32) else: + assert mask is not None mask = jnp.logical_not(mask) q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,) - if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]: + if attn_mask_type == AttnMaskType.PADDING_MASK: kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,) else: # When mask is causal, the actual seqlen is not the last row, use max to find it kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2)) - output, softmax_aux, rng_state = cross_fused_attn_fwd(q, - kv, - bias, - q_actual_seqlen, - kv_actual_seqlen, - seed, - attn_bias_type=attn_bias_type.value, - attn_mask_type=attn_mask_type.value, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + output, softmax_aux, rng_state = fused_attn_fwd_kvpacked( + q, + kv, + bias, + q_actual_seqlen, + kv_actual_seqlen, + seed, + attn_bias_type=attn_bias_type.value, + attn_mask_type=attn_mask_type.value, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) output = checkpoint_name(output, 'context') softmax_aux = checkpoint_name(softmax_aux, 'context') rng_state = checkpoint_name(rng_state, 'context') return output, (q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen) -def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, - is_training, ctx, dz): +def _fused_attn_bwd_kvpacked_rule(attn_bias_type, attn_mask_type, scaling_factor, + 
dropout_probability, is_training, ctx, dz): q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx - grad_q, grad_kv, grad_bias = cross_fused_attn_bwd(q, - kv, - bias, - softmax_aux, - rng_state, - output, - dz, - q_actual_seqlen, - kv_actual_seqlen, - attn_bias_type=attn_bias_type.value, - attn_mask_type=attn_mask_type.value, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + grad_q, grad_kv, grad_bias = fused_attn_bwd_kvpacked(q, + kv, + bias, + softmax_aux, + rng_state, + output, + dz, + q_actual_seqlen, + kv_actual_seqlen, + attn_bias_type=attn_bias_type.value, + attn_mask_type=attn_mask_type.value, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) if attn_bias_type == AttnBiasType.NO_BIAS: grad_bias = None @@ -239,7 +259,7 @@ def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, d return grad_q, grad_kv, grad_bias, None, None -_cross_fused_attn.defvjp(_cross_fused_attn_fwd_rule, _cross_fused_attn_bwd_rule) +_fused_attn_kvpacked.defvjp(_fused_attn_fwd_kvpacked_rule, _fused_attn_bwd_kvpacked_rule) def fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, @@ -277,15 +297,16 @@ def _fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarra def _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training): - if mask is None: + if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: batch, s_q, *_ = q.shape s_kv = k.shape[1] q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32) kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32) else: + assert mask is not None mask = jnp.logical_not(mask) q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,) - if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]: + if attn_mask_type == AttnMaskType.PADDING_MASK: kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,) else: # When mask is causal, the actual seqlen is not the last row, use max to find it From c1a68f6cb40a73d7e7bf77aa73c2820ffaf47281 Mon Sep 17 00:00:00 2001 From: Jaemin Choi Date: Fri, 22 Mar 2024 11:01:14 -0700 Subject: [PATCH 11/16] Enable TP-AG overlap with return_layernorm_output (#727) * Enable TP-AG overlap with return_layernorm_output Signed-off-by: Jaemin Choi * Use ub_overlap_ag Signed-off-by: Jaemin Choi --------- Signed-off-by: Jaemin Choi Co-authored-by: Jaemin Choi --- .../pytorch/module/layernorm_linear.py | 34 ++++++++++++++----- 1 file changed, 25 insertions(+), 9 deletions(-) diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 3711d9898f..57c7e75e9e 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -107,13 +107,18 @@ def forward( if ub_overlap_ag: tp_world_size = get_distributed_world_size(tp_group) - if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output: + if tp_world_size == 1 or (not is_grad_enabled): ub_overlap_ag = False if ub_overlap_ag: dim_size = list(inputmat.size()) dim_size[0] = dim_size[0] * tp_world_size ub_obj_lnout = get_ub(ub_name+"_fprop") - ln_out = ub_obj_lnout.get_ubuf_output(0) + if return_layernorm_output: + # First prepare LN output in higher precision, + # which will be 
later copied to a FP8 UB + ln_out = torch.empty_like(inputmat) + else: + ln_out = ub_obj_lnout.get_ubuf_output(0) else: ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype) @@ -136,7 +141,8 @@ def forward( ln_out_gathered = False if ub_overlap_ag: ln_out_total = ub_obj_lnout.get_ubuf_output(1) - ln_out = torch.empty_like(ln_out) + if not return_layernorm_output: + ln_out = torch.empty_like(ln_out) if ub_obj_lnout.is_atomic_gemm(): ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P else: @@ -153,12 +159,22 @@ def forward( if return_layernorm_output: ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out if fp8: - ln_out = tex.cast_to_fp8( - ln_out, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - ) + if ub_overlap_ag: + ln_out_fp8 = ub_obj_lnout.get_ubuf_output(0) + tex.cast_to_fp8( + ln_out, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype_forward, + out=ln_out_fp8) + ln_out = ln_out_fp8 + else: + ln_out = tex.cast_to_fp8( + ln_out, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype_forward, + ) if fp8: bias_dtype = ( From df1b16dae798ab69fc3b50eb77411a2e12c190f7 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Thu, 28 Mar 2024 21:49:01 -0700 Subject: [PATCH 12/16] [PyTorch] Fix bug in FP8 cast in LayerNormLinear/LayerNormMLP (#738) Perform FP8 cast on gathered layernorm output in LayerNormLinear Signed-off-by: Tim Moon --- .../pytorch/cpp_extensions/gemm.py | 2 ++ .../pytorch/module/layernorm_linear.py | 11 ++++++-- .../pytorch/module/layernorm_mlp.py | 27 ++++++++++++++----- 3 files changed, 32 insertions(+), 8 deletions(-) diff --git a/transformer_engine/pytorch/cpp_extensions/gemm.py b/transformer_engine/pytorch/cpp_extensions/gemm.py index df571a0e6b..c270fef652 100644 --- a/transformer_engine/pytorch/cpp_extensions/gemm.py +++ b/transformer_engine/pytorch/cpp_extensions/gemm.py @@ -44,6 +44,8 @@ def fp8_gemm( assert fp8_meta_tensor is not None and out_index is not None assert_dim_for_fp8_exec(A) assert_dim_for_fp8_exec(B) + assert A.dtype == torch.uint8 + assert B.dtype == torch.uint8 if out is None: out = torch.empty( diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 57c7e75e9e..551c070eb9 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -169,12 +169,19 @@ def forward( out=ln_out_fp8) ln_out = ln_out_fp8 else: - ln_out = tex.cast_to_fp8( - ln_out, + ln_out_total = tex.cast_to_fp8( + ln_out_total, fp8_meta["scaling_fwd"], tex.FP8FwdTensors.GEMM1_INPUT, fp8_dtype_forward, ) + if ln_out_gathered: + rank = torch.distributed.get_rank(tp_group) + slice_start = rank * ln_out.size(0) + slice_end = (rank + 1) * ln_out.size(0) + ln_out = ln_out_total[slice_start:slice_end, ...] 
+ else: + ln_out = ln_out_total if fp8: bias_dtype = ( diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py index 7d86658260..979c3068f5 100644 --- a/transformer_engine/pytorch/module/layernorm_mlp.py +++ b/transformer_engine/pytorch/module/layernorm_mlp.py @@ -187,12 +187,27 @@ def forward( if return_layernorm_output: ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out if fp8: - ln_out = tex.cast_to_fp8( - ln_out, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - ) + if ub_overlap_ag: + ln_out = tex.cast_to_fp8( + ln_out, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype_forward, + ) + else: + ln_out_total = tex.cast_to_fp8( + ln_out_total, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype_forward, + ) + if ln_out_gathered: + rank = torch.distributed.get_rank(tp_group) + slice_start = rank * ln_out.size(0) + slice_end = (rank + 1) * ln_out.size(0) + ln_out = ln_out_total[slice_start:slice_end, ...] + else: + ln_out = ln_out_total if fp8: bias_dtype = ( From 12cbd863a91e5dac5dc443dbf48618633e68533c Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Thu, 28 Mar 2024 21:53:59 -0700 Subject: [PATCH 13/16] [PyTorch] Fix backward compatibility with checkpoint API (#740) * Fix backward compatibility with checkpoint API Signed-off-by: Kirthi Shankar Sivamani * review comments and fix lint Signed-off-by: Kirthi Shankar Sivamani --------- Signed-off-by: Kirthi Shankar Sivamani --- transformer_engine/pytorch/distributed.py | 37 +++++++++++++++++------ 1 file changed, 27 insertions(+), 10 deletions(-) diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py index 67fc4db0d0..6a2a801efd 100644 --- a/transformer_engine/pytorch/distributed.py +++ b/transformer_engine/pytorch/distributed.py @@ -516,6 +516,12 @@ def checkpoint( kwargs : dict dictionary of string keys for keyword arguments to :attr:`function`. """ + only_tensor_args = True + for arg in args: + if not isinstance(arg, torch.Tensor): + only_tensor_args = False + break + # Pop out te.distributed.checkpoint() arguments global _USE_REENTRANT_ACTIVATION_RECOMPUTE _USE_REENTRANT_ACTIVATION_RECOMPUTE = kwargs.pop("use_reentrant", True) @@ -523,6 +529,27 @@ def checkpoint( tp_group = kwargs.pop("tp_group", None) get_rng_state_tracker = kwargs.pop("get_rng_state_tracker", None) + # Ensure backward compatibility. + if not only_tensor_args: + warnings.warn( + "Passing non-tensor non-keyword arguments is deprecated and support will be removed in " + "future releases of TransformerEngine. `distribute_saved_activations`, `tp_group`, and " + "`get_rng_state_tracker` must be passed as keyword arguments to `checkpoint`.", + DeprecationWarning, stacklevel=2, + ) + assert len(args) > 3, "Incorrect number of arguments for deprecated `checkpoint` API." + assert ( + isinstance(args[0], bool) and callable(args[1]) + and isinstance(args[2], None | dist_group_type) + ), "Incorrect arguments for deprecated `checkpoint` API." + for arg in args[3:]: + assert ( + isinstance(arg, None | torch.Tensor) + ), f"Expected tensor argument, found {type(arg)}." + + distribute_saved_activations, get_rng_state_tracker, tp_group = args[:3] # pylint: disable=unbalanced-tuple-unpacking + args = args[3:] + # Trigger the native PyTorch checkpoint if: # 1. 
`function` is a `torch.nn.Module` # AND @@ -555,16 +582,6 @@ def checkpoint( assert torch.distributed.is_initialized(), "torch.distributed is not initialized." tp_group = torch.distributed.GroupMember.WORLD if tp_group is None else tp_group - # Make sure at least one tensor input has `requires_grad=True` - input_requires_grad = False - for arg in args: - if isinstance(arg, torch.Tensor) and arg.requires_grad: - input_requires_grad = True - break - assert input_requires_grad, ( - "`use_reentrant=True` requires at least one input tensor with `requires_grad=True`." - ) - return _CheckpointFunction.apply( function, distribute_saved_activations, From 16a469df6bbc77e1c32e48e8e5fd3082dbc2d18e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pawe=C5=82=20Gadzi=C5=84ski?= <62263673+pggPL@users.noreply.github.com> Date: Sun, 31 Mar 2024 12:11:43 -0700 Subject: [PATCH 14/16] Llama tutorial fixes (#730) Llama tutorial fixes - all Signed-off-by: Pawel Gadzinski Co-authored-by: Pawel Gadzinski --- docs/examples/te_llama/te_llama.py | 46 +++++++++++-------- ...tutorial_accelerate_hf_llama_with_te.ipynb | 9 ++-- docs/examples/te_llama/utils.py | 1 + 3 files changed, 34 insertions(+), 22 deletions(-) mode change 100644 => 100755 docs/examples/te_llama/te_llama.py mode change 100644 => 100755 docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb mode change 100644 => 100755 docs/examples/te_llama/utils.py diff --git a/docs/examples/te_llama/te_llama.py b/docs/examples/te_llama/te_llama.py old mode 100644 new mode 100755 index c73bed45b4..aa23b638f0 --- a/docs/examples/te_llama/te_llama.py +++ b/docs/examples/te_llama/te_llama.py @@ -56,7 +56,7 @@ def __init__(self, config, *args, **kwargs): normalization="RMSNorm", activation="swiglu", attn_input_format="bshd", - num_gqa_groups=config.num_key_value_heads, + num_gqa_groups=config.num_key_value_heads ) te_rope = RotaryPositionEmbedding(config.hidden_size//config.num_attention_heads) self.te_rope_emb = te_rope(max_seq_len=config.max_position_embeddings).cuda() @@ -121,12 +121,12 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k assert not isinstance(resolved_archive_file, list) resolved_archive_file = [resolved_archive_file] - error_msgs = [] for shard_file in resolved_archive_file: state_dict = load_state_dict(shard_file) - replaced_layers = replace_params(state_dict, vanilla_model.state_dict()) - - error_msgs += _load_state_dict_into_model(vanilla_model, state_dict, start_prefix="") + # replace_params copies parameters relevant only to TransformerEngine + replace_params(state_dict, vanilla_model.state_dict(), config) + # _load_state_dict_into_model copies parameters other than those in TransformerEngine + _load_state_dict_into_model(vanilla_model, state_dict, start_prefix="") # Force mem release. 
Taken from huggingface code del state_dict @@ -134,7 +134,7 @@ def from_pretrained_local(cls, pretrained_model_name_or_path, *args, config, **k return vanilla_model -def replace_params(hf_state_dict, te_state_dict): +def replace_params(hf_state_dict, te_state_dict, config): # collect all layer prefixes to update all_layer_prefixes = set() for param_key in hf_state_dict.keys(): @@ -142,32 +142,40 @@ def replace_params(hf_state_dict, te_state_dict): m = re.match(layer_prefix_pat, param_key) if m is not None: all_layer_prefixes.add(m.group()) + + for layer_prefix in all_layer_prefixes: # When loading weights into models with less number of layers, skip the - # copy if the corresponding layer doesn't exist in TE model - if layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight' in te_state_dict: + # copy if the corresponding layer doesn't exist in HF model + if layer_prefix + 'input_layernorm.weight' in hf_state_dict: te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'input_layernorm.weight'].data[:] - if layer_prefix + 'self_attention.layernorm_qkv.query_weight' in te_state_dict: + if layer_prefix + 'self_attn.q_proj.weight' in hf_state_dict: te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.query_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.q_proj.weight'].data[:] - if layer_prefix + 'self_attention.layernorm_qkv.key_weight' in te_state_dict: + if layer_prefix + 'self_attn.k_proj.weight' in hf_state_dict: te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.key_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.k_proj.weight'].data[:] - if layer_prefix + 'self_attention.layernorm_qkv.value_weight' in te_state_dict: + if layer_prefix + 'self_attn.v_proj.weight' in hf_state_dict: te_state_dict[layer_prefix + 'self_attention.layernorm_qkv.value_weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.v_proj.weight'].data[:] - if layer_prefix + 'self_attention.proj.weight' in te_state_dict: + if layer_prefix + 'self_attn.o_proj.weight' in hf_state_dict: te_state_dict[layer_prefix + 'self_attention.proj.weight'].data[:] = hf_state_dict[layer_prefix + 'self_attn.o_proj.weight'].data[:] - if layer_prefix + 'layernorm_mlp.layer_norm_weight' in te_state_dict: + if layer_prefix + 'post_attention_layernorm.weight' in hf_state_dict: te_state_dict[layer_prefix + 'layernorm_mlp.layer_norm_weight'].data[:] = hf_state_dict[layer_prefix + 'post_attention_layernorm.weight'].data[:] - - if layer_prefix + 'layernorm_mlp.fc1_weight' in te_state_dict: - te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:] = torch.cat((hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data[:], hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data[:]), dim=0) - - if layer_prefix + 'layernorm_mlp.fc2_weight' in te_state_dict: + + # It may happen that gate_proj.weight and up_proj.weight will be in the different files, so we need to + # load them separately. 
+        if layer_prefix + 'mlp.gate_proj.weight' in hf_state_dict:
+            te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[:config.intermediate_size] = \
+                hf_state_dict[layer_prefix + 'mlp.gate_proj.weight'].data
+
+        if layer_prefix + 'mlp.up_proj.weight' in hf_state_dict:
+            te_state_dict[layer_prefix + 'layernorm_mlp.fc1_weight'].data[config.intermediate_size:] = \
+                hf_state_dict[layer_prefix + 'mlp.up_proj.weight'].data
+
+        if layer_prefix + 'mlp.down_proj.weight' in hf_state_dict:
             te_state_dict[layer_prefix + 'layernorm_mlp.fc2_weight'].data[:] = hf_state_dict[layer_prefix + 'mlp.down_proj.weight'].data[:]
-    return all_layer_prefixes
\ No newline at end of file
diff --git a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
old mode 100644
new mode 100755
index 178922c9d2..cc77b484f9
--- a/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
+++ b/docs/examples/te_llama/tutorial_accelerate_hf_llama_with_te.ipynb
@@ -231,7 +231,8 @@
    "\n",
    "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
    "## !!! `model_name` attr must point to the location of the model weights !!!\n",
-    "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n",
+    "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ and then converted to the HuggingFace format.\n",
+    "## Instructions for conversion are available on the website https://ai.meta.com/blog/5-steps-to-getting-started-with-llama-2/ - steps 1 and 2.\n",
    "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
    "hyperparams.mixed_precision = \"bf16\"\n",
    "\n",
@@ -556,7 +557,8 @@
    "\n",
    "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
    "## !!! `model_name` attr must point to the location of the model weights !!!\n",
-    "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n",
+    "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ and then converted to the HuggingFace format.\n",
+    "## Instructions for conversion are available on the website https://ai.meta.com/blog/5-steps-to-getting-started-with-llama-2/ - steps 1 and 2.\n",
    "hyperparams.model_name = \"\" # <== Add model weight location here e.g. \"/path/to/downloaded/llama/weights\"\n",
    "hyperparams.mixed_precision = \"bf16\"\n",
    "\n",
@@ -635,7 +637,8 @@
    "\n",
    "# Default hyperparams, also defined in `utils.py` in class `Hyperparameters`\n",
    "## !!! `model_name` attr must point to the location of the model weights !!!\n",
-    "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/\n",
+    "## Weights can be downloaded from: https://llama.meta.com/llama-downloads/ and then converted to the HuggingFace format.\n",
+    "## Instructions for conversion are available on the website https://ai.meta.com/blog/5-steps-to-getting-started-with-llama-2/ - steps 1 and 2.\n",
    "hyperparams.model_name = \"\" # <== Add model weight location here e.g. 
\"/path/to/downloaded/llama/weights\"\n", "hyperparams.mixed_precision = \"fp8\"\n", "\n", diff --git a/docs/examples/te_llama/utils.py b/docs/examples/te_llama/utils.py old mode 100644 new mode 100755 index 54b329f12b..9c36e5bd17 --- a/docs/examples/te_llama/utils.py +++ b/docs/examples/te_llama/utils.py @@ -91,6 +91,7 @@ def init_te_llama_model(hyperparams): # Init the model from te_llama import TELlamaForCausalLM config = AutoConfig.from_pretrained(hyperparams.model_name) + config._attn_implementation = "flash_attention_2" model = TELlamaForCausalLM.from_pretrained_local( hyperparams.model_name, config=config, From 2dd6b14657723bd9e389c22d52730e0a5fe0e204 Mon Sep 17 00:00:00 2001 From: Tim Moon <4406448+timmoon10@users.noreply.github.com> Date: Tue, 2 Apr 2024 10:16:04 -0700 Subject: [PATCH 15/16] Set CUDA context before loading NVRTC kernels (#734) Signed-off-by: Tim Moon --- transformer_engine/common/util/rtc.cpp | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/transformer_engine/common/util/rtc.cpp b/transformer_engine/common/util/rtc.cpp index d0fcbe61cd..87271e97b0 100644 --- a/transformer_engine/common/util/rtc.cpp +++ b/transformer_engine/common/util/rtc.cpp @@ -76,6 +76,10 @@ Kernel::~Kernel() { != CUDA_SUCCESS) { continue; } + if (cuda_driver::call("cuCtxSetCurrent", context) + != CUDA_SUCCESS) { + continue; + } cuda_driver::call("cuModuleUnload", modules_[device_id]); cuda_driver::call("cuDevicePrimaryCtxRelease", device); } @@ -109,6 +113,7 @@ CUfunction Kernel::get_function(int device_id) { CUcontext context; NVTE_CALL_CHECK_CUDA_DRIVER(cuDeviceGet, &device, device_id); NVTE_CALL_CHECK_CUDA_DRIVER(cuDevicePrimaryCtxRetain, &context, device); + NVTE_CALL_CHECK_CUDA_DRIVER(cuCtxSetCurrent, context); // Load function into driver context NVTE_CALL_CHECK_CUDA_DRIVER(cuModuleLoadDataEx, From 47276e1bfbde0423f0fbc698e36f2652a1d9895e Mon Sep 17 00:00:00 2001 From: Kirthi Shankar Sivamani Date: Tue, 2 Apr 2024 19:07:50 -0700 Subject: [PATCH 16/16] Revert "Update FA version to 2.5.6 (#714)" This reverts commit 965803c9dcf889f7e9820921cb811c2bf9b91dbc. 
--- setup.py | 2 +- transformer_engine/pytorch/attention.py | 4 ---- 2 files changed, 1 insertion(+), 5 deletions(-) diff --git a/setup.py b/setup.py index e6bfe496ff..d50a9b8706 100644 --- a/setup.py +++ b/setup.py @@ -265,7 +265,7 @@ def add_unique(l: List[str], vals: Union[str, List[str]]) -> None: # Framework-specific requirements if "pytorch" in frameworks(): - add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.5.6,!=2.0.9,!=2.1.0"]) + add_unique(install_reqs, ["torch", "flash-attn>=2.0.6,<=2.4.2,!=2.0.9,!=2.1.0"]) add_unique(test_reqs, ["numpy", "onnxruntime", "torchvision"]) if "jax" in frameworks(): if not found_pybind11(): diff --git a/transformer_engine/pytorch/attention.py b/transformer_engine/pytorch/attention.py index 924e2bb97d..f5e7753e6a 100644 --- a/transformer_engine/pytorch/attention.py +++ b/transformer_engine/pytorch/attention.py @@ -58,7 +58,6 @@ _flash_attn_version = packaging.version.Version(version("flash-attn")) _flash_attn_version_required = packaging.version.Version("2.0.6") -_flash_attn_max_version = packaging.version.Version("2.5.6") _flash_attn_2_1_plus = _flash_attn_version >= packaging.version.Version("2.1") _flash_attn_2_3_plus = _flash_attn_version >= packaging.version.Version("2.3") _flash_attn_2_4_plus = _flash_attn_version >= packaging.version.Version("2.4") @@ -1658,9 +1657,6 @@ def __init__( assert ( _flash_attn_version >= _flash_attn_version_required ), f"FlashAttention minimum version {_flash_attn_version_required} is required." - assert ( - _flash_attn_version <= _flash_attn_max_version - ), f"FlashAttention maximum version {_flash_attn_max_version} is supported." self.norm_factor = norm_factor self.attention_dropout_ctx = attention_dropout_ctx