diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 580125290b..d5b23667d1 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -16,7 +16,7 @@ from utils import make_causal_mask, make_self_mask from transformer_engine.jax import fp8_autocast from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available -from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn +from transformer_engine.jax.fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout DTYPES = [jnp.float16, jnp.bfloat16] @@ -86,15 +86,15 @@ def test_self_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, dat def target_func(qkv, bias, mask): return jnp.mean( - self_fused_attn(qkv, - bias, - mask, - None, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_prob, - is_training=is_training)) + fused_attn_qkvpacked(qkv, + bias, + mask, + None, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_prob, + is_training=is_training)) def ref_func(qkv, bias, mask): query, key, value = jnp.split(qkv, [1, 2], axis=-3) @@ -192,16 +192,16 @@ def test_cross_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, da def target_func(q, kv, mask): return jnp.mean( - cross_fused_attn(q, - kv, - None, - mask, - None, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_prob, - is_training=is_training)) + fused_attn_kvpacked(q, + kv, + None, + mask, + None, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_prob, + is_training=is_training)) def ref_func(query, kv, mask): key, value = jnp.split(kv, [1], axis=-3) diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 35a1e4218f..483f070559 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -2,8 +2,6 @@ # # See LICENSE for license information. 
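The hunks above only rename the test entry points: `self_fused_attn` becomes `fused_attn_qkvpacked` and `cross_fused_attn` becomes `fused_attn_kvpacked`, with the positional signatures left unchanged. A minimal usage sketch of the packed-QKV path as the updated tests call it (shapes, dtype, and the all-zeros mask are illustrative, not taken from the tests):

import jax
import jax.numpy as jnp
from transformer_engine.jax.fused_attn import fused_attn_qkvpacked
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType

b, s, h, d = 4, 128, 16, 64
# Packed QKV in the BS3HD layout expected by the qkvpacked entry point.
qkv = jax.random.normal(jax.random.PRNGKey(0), (b, s, 3, h, d), dtype=jnp.bfloat16)
mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)    # all-zeros mask: nothing masked out (illustrative)

out = fused_attn_qkvpacked(qkv, None, mask, None,  # bias=None with NO_BIAS, seed=None without dropout
                           attn_bias_type=AttnBiasType.NO_BIAS,
                           attn_mask_type=AttnMaskType.NO_MASK,
                           scaling_factor=1.0 / d**0.5,
                           dropout_probability=0.0,
                           is_training=True)

The kvpacked variant takes a separate `q` plus a packed `kv` of shape (..., s_kv, 2, h_kv, d), followed by bias, mask, and an optional RNG key, as in the `test_cross_attn` hunk above.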
"""Tests for fused attention""" -import sys - from enum import Enum from dataclasses import dataclass from functools import partial @@ -21,7 +19,7 @@ from jax.typing import ArrayLike, DTypeLike from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout -from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn, fused_attn +from transformer_engine.jax.fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked, fused_attn from transformer_engine.jax.cpp_extensions import FusedAttnHelper from transformer_engine_jax import NVTE_Fused_Attn_Backend @@ -144,18 +142,22 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng case QKVLayout.BS3HD: query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value]) qkv = jnp.concatenate((query, key, value), axis=-3) - return self_fused_attn(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype) + return fused_attn_qkvpacked(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype) case QKVLayout.BSHD_BS2HD: key, value = map(partial(jnp.expand_dims, axis=-3), [key, value]) kv = jnp.concatenate((key, value), axis=-3) - return cross_fused_attn(query, kv, bias, mask, dropout_rng, - **kwargs).astype(query.dtype) + return fused_attn_kvpacked(query, kv, bias, mask, dropout_rng, + **kwargs).astype(query.dtype) case QKVLayout.BSHD_BSHD_BSHD: return fused_attn(query, key, value, bias, mask, dropout_rng, **kwargs).astype(query.dtype) class BiasShape(Enum): + """ + Enum class to represent the different bias shapes used in the fused attention. + """ + BIAS_1HSS = '1HSS' BIAS_B1SS = 'B1SS' BIAS_BHSS = 'BHSS' @@ -188,17 +190,16 @@ def _check_configs(self): if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv: pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.") - self.backend = FusedAttnHelper( - self.dtype, self.dtype, self.qkv_layout.value, self.attn_bias_type.value, - self.attn_mask_type.value, self.dropout_prob, self.num_heads_q, self.num_heads_kv, - self.max_seqlen_q, self.max_seqlen_kv, self.head_dim).get_fused_attn_backend() + self.backend = FusedAttnHelper(self.dtype, self.dtype, self.qkv_layout.value, + self.attn_bias_type.value, self.attn_mask_type.value, + self.dropout_prob, self.num_heads_q, self.num_heads_kv, + self.max_seqlen_q, self.max_seqlen_kv, + self.head_dim).get_fused_attn_backend() if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend: pytest.skip("Unsupported inputs combination or device compute capability.") - if self.bias_shape != BiasShape.BIAS_1HSS: - if self.attn_bias_type != AttnBiasType.POST_SCALE_BIAS: - pytest.skip("B1SS, BHSS and 11SS bias shapes require POST_SCALE_BIAS.") - elif self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: + if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS: + if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for " "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK.") elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: @@ -213,7 +214,9 @@ def _setup_inputs(self): q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim) k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim) - if self.bias_shape == BiasShape.BIAS_1HSS: + if self.attn_bias_type == AttnBiasType.NO_BIAS: + bias_shape = None + elif self.bias_shape == BiasShape.BIAS_1HSS: bias_shape = (1, 
self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv) elif self.bias_shape == BiasShape.BIAS_B1SS: bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv) @@ -222,7 +225,7 @@ def _setup_inputs(self): elif self.bias_shape == BiasShape.BIAS_11SS: bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv) else: - pytest.xfail("PyTest attempted to use an unrecognized bias layout!") + pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!") self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.) self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.) @@ -237,7 +240,7 @@ def _setup_inputs(self): cudnn_neg_inf = -2.**27. if self.dtype == jnp.bfloat16 else -2.**15. self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype) max_id = min(self.max_seqlen_q, self.max_seqlen_kv) - seq_id_size = max_id * 5 // 128 # 5 ids per interval of 128 sequences + seq_id_size = max_id * 5 // 128 # 5 ids per interval of 128 sequences seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist() for i in range(1, len(seq_id)): self.bias = \ @@ -327,8 +330,8 @@ def grad_func(func, *args, **kwargs): **kwargs), arg_nums)) jitted_reference = jit( value_and_grad( - lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, - **kwargs), arg_nums)) + lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs), + arg_nums)) primitive_out, primitive_dgrad = jitted_primitive(*args) reference_out, reference_dgrad = jitted_reference(*args) @@ -361,10 +364,10 @@ def check_dqkv(primitive, reference, valid_len): primitive_dbias = jnp.float32(primitive_dgrad[3]) reference_dbias = jnp.float32(reference_dgrad[3]) - assert_allclose( - primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:], - jnp.zeros_like(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:]), - dtype=self.dtype) + assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:], + jnp.zeros_like(primitive_dbias[..., self.valid_len_q:, + self.valid_len_kv:]), + dtype=self.dtype) # dbias padded part assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:], @@ -376,15 +379,13 @@ def check_dqkv(primitive, reference, valid_len): reference_dbias[..., :self.valid_len_q, :self.valid_len_kv], dtype=self.dtype) -@pytest.mark.parametrize('bias_shape', [ - pytest.param(BiasShape.BIAS_1HSS, id='1-H-S-S'), - pytest.param(BiasShape.BIAS_B1SS, id='B-1-S-S'), - pytest.param(BiasShape.BIAS_BHSS, id='B-H-S-S'), - pytest.param(BiasShape.BIAS_11SS, id='1-1-S-S'), -]) -@pytest.mark.parametrize('attn_bias_type', [ - pytest.param(AttnBiasType.NO_BIAS, id='NO_BIAS'), - pytest.param(AttnBiasType.POST_SCALE_BIAS, id='POST_SCALE_BIAS'), + +@pytest.mark.parametrize('attn_bias_type, bias_shape', [ + pytest.param(AttnBiasType.NO_BIAS, None, id='NO_BIAS'), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id='POST_SCALE_BIAS-1HSS'), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id='POST_SCALE_BIAS-B1SS'), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id='POST_SCALE_BIAS-BHSS'), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id='POST_SCALE_BIAS-11SS'), ]) @pytest.mark.parametrize('attn_mask_type', [ pytest.param(AttnMaskType.NO_MASK, id='NO_MASK'), @@ -399,31 +400,32 @@ def check_dqkv(primitive, reference, valid_len): ]) @pytest.mark.parametrize('dtype', [ pytest.param(jnp.bfloat16, id="BF16"), - pytest.param(jnp.float16, id="FP16") + pytest.param(jnp.float16, 
id="FP16"), ]) -@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d',[ - pytest.param(32, 128, 128, 16, 16, 64, id='32-128-128-16-16-64-SELF'), - pytest.param( 4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'), - pytest.param(32, 512, 128, 16, 16, 64, id='32-512-128-16-16-64-CROSS'), - pytest.param( 4, 2048, 1024, 12, 12, 64, id='4-2048-1048-12-12-64-CROSS'), - pytest.param(32, 128, 128, 16, 8, 64, id='32-128-128-16-8-64-GQA'), - pytest.param( 4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA') +@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d', [ + pytest.param(32, 128, 128, 16, 16, 64, id='32-128-128-16-16-64-SELF'), + pytest.param(4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'), + pytest.param(32, 512, 128, 16, 16, 64, id='32-512-128-16-16-64-CROSS'), + pytest.param(4, 2048, 1024, 12, 12, 64, id='4-2048-1048-12-12-64-CROSS'), + pytest.param(32, 128, 128, 16, 8, 64, id='32-128-128-16-8-64-GQA'), + pytest.param(4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA'), ]) @pytest.mark.parametrize('dropout_prob', [ pytest.param(0.0, id="DROP_0.0"), - pytest.param(0.1, id="DROP_0.1") -]) -@pytest.mark.parametrize('is_training', [ - pytest.param(True, id='TRAINING'), - pytest.param(False, id='INFERENCE'), + pytest.param(0.1, id="DROP_0.1"), ]) class TestFusedAttn: """ Fused attention tester """ + @staticmethod - def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, - dropout_prob, dtype, is_training, qkv_layout, bias_shape): + @pytest.mark.parametrize('is_training', [ + pytest.param(True, id='TRAINING'), + pytest.param(False, id='INFERENCE'), + ]) + def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob, + dtype, is_training, qkv_layout, bias_shape): """ Test forward with parameterized configs """ @@ -432,13 +434,11 @@ def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, runner.test_forward() @staticmethod - def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, - dropout_prob, dtype, is_training, qkv_layout, bias_shape): + def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob, + dtype, qkv_layout, bias_shape): """ Test backward with parameterized configs """ - if not is_training: - pytest.skip("Backward pass does not support inference.") runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob, dtype, True, qkv_layout, bias_shape) runner.test_backward() diff --git a/tests/jax/test_layer.py b/tests/jax/test_layer.py index 175819e115..1b7b4087d0 100644 --- a/tests/jax/test_layer.py +++ b/tests/jax/test_layer.py @@ -449,6 +449,7 @@ def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): hidden_dropout_dims=(sequence_dim,), intermediate_dropout_dims=(sequence_dim,), layer_type=TransformerLayerType.DECODER, + self_attn_mask_type='padding_causal', dtype=dtype, **te_layer_attrs) ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs, @@ -497,6 +498,7 @@ def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e- hidden_dropout_dims=(sequence_dim,), intermediate_dropout_dims=(sequence_dim,), layer_type=TransformerLayerType.DECODER, + self_attn_mask_type='padding_causal', dtype=dtype, **te_layer_attrs) ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs, diff --git a/tests/jax/test_praxis_layers.py b/tests/jax/test_praxis_layers.py index d9c9e17e97..43581f1015 100644 --- 
a/tests/jax/test_praxis_layers.py +++ b/tests/jax/test_praxis_layers.py @@ -730,8 +730,13 @@ class TestDotProductAttn(TestLayer): def input_getter(self, shape, dtype): key = jax.random.PRNGKey(seed=1234) q_key, k_key, v_key = jax.random.split(key, 3) - return list(map(partial(jax.random.normal, shape=shape, dtype=dtype), - [q_key, k_key, v_key])) + b, s, *_ = shape + if self.attrs[DotProductAttnAttr.TRANSPOSE_BS]: + b, s = s, b + mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8) + return [ + *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]), mask + ] def get_layer_name(self): return 'dot_product_attn' @@ -765,6 +770,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): @pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize('attrs', DotProductAttnAttr.ATTRS) def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): + self.attrs = attrs praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) @@ -853,9 +859,11 @@ class MultiHeadAttnAttr: class TestMultiHeadAttn(TestLayer): def input_getter(self, shape, dtype): - data_key = jax.random.PRNGKey(seed=1234) - return (jax.random.normal(data_key, shape, - dtype), jax.random.normal(data_key, shape, dtype)) + key = jax.random.PRNGKey(seed=1234) + q_key, kv_key = jax.random.split(key, 2) + s, b, *_ = shape + mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8) + return [*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask] def get_layer_name(self): return 'multi_head_attn' @@ -1183,9 +1191,15 @@ class TransformerLayerAttr: class TestTransformer(TestLayer): def input_getter(self, shape, dtype): - data_key = jax.random.PRNGKey(seed=1234) - return (jax.random.normal(data_key, shape, - dtype), jax.random.normal(data_key, shape, dtype)) + key = jax.random.PRNGKey(seed=1234) + q_key, kv_key = jax.random.split(key, 2) + b, s, *_ = shape + if self.attrs[TransformerLayerAttr.TRANSPOSE_BS]: + b, s = s, b + mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8) + return [ + *map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask, mask + ] def get_layer_name(self): return 'transformerlayer' @@ -1277,6 +1291,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs): @pytest.mark.parametrize('dtype', DTYPE) @pytest.mark.parametrize('attrs', TransformerLayerAttr.ATTRS) def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08): + self.attrs = attrs praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol) @@ -1292,7 +1307,7 @@ def test_forward_backward_fp8(self, fp8_format, rtol=1e-05, atol=1e-08): - + self.attrs = attrs ds = DelayedScaling(fp8_format=fp8_format) with fp8_autocast(enabled=True, fp8_recipe=ds): praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs) diff --git a/transformer_engine/jax/cpp_extensions.py b/transformer_engine/jax/cpp_extensions.py index ba1d44318c..08bcb94239 100644 --- a/transformer_engine/jax/cpp_extensions.py +++ b/transformer_engine/jax/cpp_extensions.py @@ -1368,14 +1368,13 @@ def batcher(batched_args, batch_dims, *, scale_factor): @staticmethod def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): return ScaledSoftmaxFwdPrimitive.forward_infer_sharding_from_operands( - scale_factor, mesh, arg_infos, result_infos - ) + scale_factor, mesh, arg_infos, result_infos) 
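The Praxis-layer test changes above make the attention mask explicit: `input_getter` now returns an all-zeros uint8 mask of shape (batch, 1, seqlen, seqlen) alongside the random query/key/value (or query/key-value) tensors, swapping the batch and sequence sizes first when the layer is configured with transposed batch/sequence dimensions. A small self-contained sketch of that pattern, with `make_attn_inputs` as a hypothetical helper name:

from functools import partial
import jax
import jax.numpy as jnp

def make_attn_inputs(shape, dtype, transpose_batch_sequence=False):
    # shape is (batch, seqlen, hidden), or (seqlen, batch, hidden) when transposed
    key = jax.random.PRNGKey(1234)
    q_key, k_key, v_key = jax.random.split(key, 3)
    b, s, *_ = shape
    if transpose_batch_sequence:
        b, s = s, b
    # All-zeros mask: no positions are masked out, matching the tests' default.
    mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
    q, k, v = map(partial(jax.random.normal, shape=shape, dtype=dtype), (q_key, k_key, v_key))
    return [q, k, v, mask]

The multi-head attention and transformer-layer tests follow the same pattern with a query/key-value split and, for the decoder case, a second mask for the cross-attention path.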
@staticmethod def partition(scale_factor, mesh, arg_infos, result_infos): - return ScaledSoftmaxFwdPrimitive.forward_partition( - ScaledSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos - ) + return ScaledSoftmaxFwdPrimitive.forward_partition(ScaledSoftmaxFwdPrimitive.impl, + scale_factor, mesh, arg_infos, + result_infos) register_primitive(ScaledSoftmaxFwdPrimitive) @@ -1444,14 +1443,13 @@ def batcher(batched_args, batch_dims, *, scale_factor): @staticmethod def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): return ScaledSoftmaxBwdPrimitive.backward_infer_sharding_from_operands( - scale_factor, mesh, arg_infos, result_infos - ) + scale_factor, mesh, arg_infos, result_infos) @staticmethod def partition(scale_factor, mesh, arg_infos, result_infos): - return ScaledSoftmaxBwdPrimitive.backward_partition( - ScaledSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos - ) + return ScaledSoftmaxBwdPrimitive.backward_partition(ScaledSoftmaxBwdPrimitive.impl, + scale_factor, mesh, arg_infos, + result_infos) register_primitive(ScaledSoftmaxBwdPrimitive) @@ -1581,14 +1579,12 @@ def batcher(batched_args, batch_dims, *, scale_factor): @staticmethod def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): return ScaledMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands( - scale_factor, mesh, arg_infos,result_infos - ) + scale_factor, mesh, arg_infos, result_infos) @staticmethod def partition(scale_factor, mesh, arg_infos, result_infos): return ScaledMaskedSoftmaxFwdPrimitive.backward_partition( - ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos - ) + ScaledMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos) register_primitive(ScaledMaskedSoftmaxFwdPrimitive) @@ -1660,14 +1656,12 @@ def batcher(batched_args, batch_dims, *, scale_factor): @staticmethod def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): return ScaledMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands( - scale_factor, mesh, arg_infos, result_infos - ) + scale_factor, mesh, arg_infos, result_infos) @staticmethod def partition(scale_factor, mesh, arg_infos, result_infos): return ScaledMaskedSoftmaxBwdPrimitive.backward_partition( - ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos - ) + ScaledMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, result_infos) register_primitive(ScaledMaskedSoftmaxBwdPrimitive) @@ -1749,15 +1743,13 @@ def batcher(batched_args, batch_dims, *, scale_factor): @staticmethod def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_infer_sharding_from_operands( - scale_factor, mesh, arg_infos, result_infos - ) + scale_factor, mesh, arg_infos, result_infos) @staticmethod def partition(scale_factor, mesh, arg_infos, result_infos): return ScaledUpperTriangMaskedSoftmaxFwdPrimitive.forward_partition( - ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, - arg_infos, result_infos - ) + ScaledUpperTriangMaskedSoftmaxFwdPrimitive.impl, scale_factor, mesh, arg_infos, + result_infos) register_primitive(ScaledUpperTriangMaskedSoftmaxFwdPrimitive) @@ -1829,15 +1821,13 @@ def batcher(batched_args, batch_dims, *, scale_factor): @staticmethod def infer_sharding_from_operands(scale_factor, mesh, arg_infos, result_infos): return 
ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_infer_sharding_from_operands( - scale_factor, mesh, arg_infos, result_infos - ) + scale_factor, mesh, arg_infos, result_infos) @staticmethod def partition(scale_factor, mesh, arg_infos, result_infos): return ScaledUpperTriangMaskedSoftmaxBwdPrimitive.backward_partition( - ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, - arg_infos, result_infos - ) + ScaledUpperTriangMaskedSoftmaxBwdPrimitive.impl, scale_factor, mesh, arg_infos, + result_infos) register_primitive(ScaledUpperTriangMaskedSoftmaxBwdPrimitive) @@ -1859,16 +1849,16 @@ class FusedAttnHelper: Helper for the fused attention backend """ - q_type: jnp.dtype - kv_type: jnp.dtype + q_dtype: jnp.dtype + kv_dtype: jnp.dtype qkv_layout: NVTE_QKV_Layout attn_bias_type: NVTE_Bias_Type attn_mask_type: NVTE_Mask_Type dropout_probability: float - num_heads_q: int - num_heads_kv: int - max_seqlen_q: int - max_seqlen_kv: int + q_num_heads: int + kv_num_heads: int + q_max_seqlen: int + kv_max_seqlen: int head_dim: int def is_fused_attn_kernel_available(self): @@ -1878,11 +1868,38 @@ def is_fused_attn_kernel_available(self): def get_fused_attn_backend(self): """Get the fused attention kernel backend""" return transformer_engine_jax.get_fused_attn_backend( - jax_dtype_to_te_dtype(self.q_type), jax_dtype_to_te_dtype(self.kv_type), + jax_dtype_to_te_dtype(self.q_dtype), jax_dtype_to_te_dtype(self.kv_dtype), self.qkv_layout, self.attn_bias_type, self.attn_mask_type, self.dropout_probability, - self.num_heads_q, self.num_heads_kv, self.max_seqlen_q, self.max_seqlen_kv, + self.q_num_heads, self.kv_num_heads, self.q_max_seqlen, self.kv_max_seqlen, self.head_dim) + @staticmethod + def parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout): + """Parse qkv aval""" + match qkv_layout: + case NVTE_QKV_Layout.NVTE_BS3HD: + *q_batch_shape, q_max_seqlen, nqkv, attn_heads, q_head_dim = q_aval.shape + kv_batch_shape = q_batch_shape + kv_max_seqlen = q_max_seqlen + num_gqa_groups = attn_heads + kv_head_dim = q_head_dim + assert nqkv == 3 + case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: + *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape + *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = k_aval.shape + assert nkv == 2 + case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: + *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape + *kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape + assert k_aval.shape == v_aval.shape + case _: + raise ValueError(f"Unexpected {qkv_layout=}") + assert q_batch_shape == kv_batch_shape + assert q_head_dim == kv_head_dim + assert q_aval.dtype == k_aval.dtype == v_aval.dtype + + return (q_batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, q_head_dim) + @dataclass(frozen=True) class _FusedAttnRNGStateChecker: @@ -1933,46 +1950,50 @@ def generate_cu_seqlen(actual_seqlen): return cu_seqlen -class SelfFusedAttnFwdPrimitive(BasePrimitive): +class FusedAttnFwdPrimitive(BasePrimitive): """ - Self Fused Attention Forward Primitive + Fused Attention Forward Primitive """ - name = "te_self_fused_attn_forward" + name = "te_fused_attn_forward" multiple_results = True - impl_static_args = (4, 5, 6, 7, 8) + impl_static_args = (7, 8, 9, 10, 11, 12) inner_primitive = None outer_primitive = None @staticmethod - def abstract(qkv_aval, bias_aval, seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, - attn_mask_type, scaling_factor, dropout_probability, is_training): + def abstract(q_aval, k_aval, v_aval, bias_aval, 
q_seqlen_or_cu_seqlen_aval, + kv_seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, attn_mask_type, + qkv_layout, scaling_factor, dropout_probability, is_training): """ - Self fused attention fwd inner primitive abstract + Fused attention fwd abstract """ - # outer_primitve is squeezed_mask, inner_primitive is cu_seqlen - del seqlen_or_cu_seqlen_aval - qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype) - *input_batch_shape, max_seqlen, nqkv, attn_heads, head_dim = qkv_aval.shape - assert nqkv == 3 - assert qkv_aval.dtype == bias_aval.dtype + q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) + k_dtype = dtypes.canonicalize_dtype(k_aval.dtype) + v_dtype = dtypes.canonicalize_dtype(v_aval.dtype) + bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) + assert q_dtype == k_dtype == v_dtype == bias_dtype + assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype + + batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = \ + FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) - output_shape = (*input_batch_shape, max_seqlen, attn_heads, head_dim) - out_aval = qkv_aval.update(shape=output_shape, dtype=qkv_dtype) + output_shape = (*batch_shape, q_max_seqlen, attn_heads, head_dim) + out_aval = q_aval.update(shape=output_shape, dtype=q_dtype) # backend determines the softmax buffer shape/dtype - backend = FusedAttnHelper(qkv_dtype, qkv_dtype, NVTE_QKV_Layout.NVTE_BS3HD, attn_bias_type, - attn_mask_type, dropout_probability, attn_heads, attn_heads, - max_seqlen, max_seqlen, head_dim).get_fused_attn_backend() + backend = FusedAttnHelper(q_dtype, k_dtype, qkv_layout, attn_bias_type, attn_mask_type, + dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, + kv_max_seqlen, head_dim).get_fused_attn_backend() if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: - softmax_shape = (*input_batch_shape, attn_heads, max_seqlen, max_seqlen) - softmax_dtype = qkv_dtype + softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen) + softmax_dtype = q_dtype elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: - softmax_shape = (*input_batch_shape, attn_heads, max_seqlen, 1) + softmax_shape = (*batch_shape, attn_heads, q_max_seqlen, 1) softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) else: raise ValueError(f'Unsupported {backend=}') - softmax_aux_aval = qkv_aval.update(shape=softmax_shape, dtype=softmax_dtype) + softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype) # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with # 32-bit unsigned int to get the buffer size we need in the C++ kernel @@ -1990,32 +2011,32 @@ def abstract(qkv_aval, bias_aval, seqlen_or_cu_seqlen_aval, seed_aval, *, attn_b # do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to # prepare for the active fused-attn backend - input_batch = reduce(operator.mul, input_batch_shape) - wkspace_info = transformer_engine_jax.get_self_fused_attn_fwd_workspace_sizes( - input_batch, bias_batch, max_seqlen, attn_heads, bias_heads, head_dim, - scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(qkv_aval.dtype), is_training) - wkspace_aval = qkv_aval.update(shape=wkspace_info[0], - dtype=te_dtype_to_jax_dtype(wkspace_info[1])) + input_batch = reduce(operator.mul, batch_shape) + wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes( + input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, 
num_gqa_groups, + bias_heads, head_dim, scaling_factor, dropout_probability, attn_bias_type, + attn_mask_type, qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), is_training) + wkspace_aval = q_aval.update(shape=wkspace_info[0], + dtype=te_dtype_to_jax_dtype(wkspace_info[1])) return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval @staticmethod def outer_abstract(*args, **kwargs): """ - Self fused attention fwd outer primitive abstract + Fused attention fwd outer primitive abstract """ out_aval, softmax_aux_aval, rng_state_aval, _ = \ - SelfFusedAttnFwdPrimitive.abstract(*args, **kwargs) + FusedAttnFwdPrimitive.abstract(*args, **kwargs) return out_aval, softmax_aux_aval, rng_state_aval @staticmethod - def lowering(ctx, qkv, bias, cu_seqlen, seed, *, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): + def lowering(ctx, q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type, + attn_mask_type, qkv_layout, scaling_factor, dropout_probability, is_training): """ - Self fused attention fwd lowering rules + Fused attention fwd lowering rules """ - operands = [qkv, bias, cu_seqlen, seed] + operands = [q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed] operand_shapes = map(lambda x: x.type.shape, operands) out_types = [ ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) @@ -2023,9 +2044,12 @@ def lowering(ctx, qkv, bias, cu_seqlen, seed, *, attn_bias_type, attn_mask_type, ] args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - qkv_aval, bias_aval, *_ = ctx.avals_in - *input_batch_shape, max_seqlen, _, attn_heads, head_dim = qkv_aval.shape - input_batch = reduce(operator.mul, input_batch_shape) + q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in + + batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = \ + FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) + + input_batch = reduce(operator.mul, batch_shape) if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: bias_batch = bias_heads = 0 @@ -2036,137 +2060,137 @@ def lowering(ctx, qkv, bias, cu_seqlen, seed, *, attn_bias_type, attn_mask_type, wkspace_aval = ctx.avals_out[-1] opaque = transformer_engine_jax.pack_fused_attn_descriptor( - input_batch, bias_batch, max_seqlen, max_seqlen, - attn_heads, attn_heads, bias_heads, head_dim, wkspace_aval.size, - scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(qkv_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - is_training) + input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, + bias_heads, head_dim, wkspace_aval.size, scaling_factor, dropout_probability, + attn_bias_type, attn_mask_type, qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), is_training) - out = custom_caller(SelfFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) + out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) return out @staticmethod - def impl(qkv, bias, seqlen, seed, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): - assert SelfFusedAttnFwdPrimitive.inner_primitive is not None + def impl(q, k, v, bias, q_seqlen, kv_seqlen, seed, attn_bias_type, attn_mask_type, qkv_layout, + scaling_factor, dropout_probability, is_training): + assert FusedAttnFwdPrimitive.inner_primitive is not None - cu_seqlen = generate_cu_seqlen(seqlen) + q_cu_seqlen = generate_cu_seqlen(q_seqlen) + kv_cu_seqlen = 
generate_cu_seqlen(kv_seqlen) - output, softmax_aux, rng_state, _ = SelfFusedAttnFwdPrimitive.inner_primitive.bind( - qkv, + output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind( + q, + k, + v, bias, - cu_seqlen, + q_cu_seqlen, + kv_cu_seqlen, seed, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, scaling_factor=scaling_factor, dropout_probability=dropout_probability, is_training=is_training) return output, softmax_aux, rng_state @staticmethod - def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): + def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, qkv_layout, + scaling_factor, dropout_probability, is_training): _check_valid_batch_dims(batch_dims) - assert SelfFusedAttnFwdPrimitive.outer_primitive is not None - qkv_bdim, _, _, seed_bdim = batch_dims + assert FusedAttnFwdPrimitive.outer_primitive is not None + q_bdim, *_, seed_bdim = batch_dims - out_bdims = qkv_bdim, qkv_bdim, seed_bdim - return SelfFusedAttnFwdPrimitive.outer_primitive.bind( - *batched_args, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training), out_bdims + out_bdims = q_bdim, q_bdim, seed_bdim + return FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training), out_bdims @staticmethod - def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor, + def infer_sharding_from_operands(attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability, is_training, mesh, arg_infos, result_infos): del attn_bias_type, attn_mask_type, scaling_factor del dropout_probability, is_training, result_infos - x_spec = get_padded_spec(arg_infos[0]) # (...batch, seqlen, 3, head, hidden) - out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-3], *x_spec[-2:])) - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec[:-4], x_spec[-2], x_spec[-4], None)) + q_spec = get_padded_spec(arg_infos[0]) + k_spec = get_padded_spec(arg_infos[1]) + match qkv_layout: + case NVTE_QKV_Layout.NVTE_BS3HD: + # q_spec = (...batch, q_seqlen, head, hidden) + out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec[:-3], *q_spec[-2:])) + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*q_spec[:-4], q_spec[-2], q_spec[-4], None)) + case NVTE_QKV_Layout.NVTE_BSHD_BS2HD: + # q_spec = (...batch, q_seqlen, head, hidden) + # k_spec = (...batch, kv_seqlen, 2, num_gqa_groups, hidden) + out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-4])) + case NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD: + # q_spec = (...batch, q_seqlen, head, hidden) + # k_spec = (...batch, kv_seqlen, num_gqa_groups, hidden) + out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + softmax_aux_sharding = NamedSharding( + mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3])) + case _: + raise ValueError(f"Unsupported {qkv_layout=}") rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None)) return (out_sharding, softmax_aux_sharding, rng_state_sharding) @staticmethod - def partition(attn_bias_type, attn_mask_type, 
scaling_factor, dropout_probability, is_training, - mesh, arg_infos, result_infos): - del result_infos - x_spec = get_padded_spec(arg_infos[0]) # (...batch, seqlen, 3, head, hidden) - out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec[:-3], *x_spec[-2:])) - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*x_spec[:-4], x_spec[-2], x_spec[-4], None)) - rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None)) - arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [rng_state_sharding]) + def partition(attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability, + is_training, mesh, arg_infos, result_infos): + out_sharding = result_infos[0].sharding + softmax_aux_sharding = result_infos[1].sharding + rng_state_sharding = seed_sharding = NamedSharding(mesh, + PartitionSpec(get_all_mesh_axes(), None)) + arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) - impl = partial(SelfFusedAttnFwdPrimitive.impl, + impl = partial(FusedAttnFwdPrimitive.impl, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, scaling_factor=scaling_factor, dropout_probability=dropout_probability, is_training=is_training) return mesh, impl, out_shardings, arg_shardings -register_primitive(SelfFusedAttnFwdPrimitive) - - -def self_fused_attn_fwd(qkv: jnp.ndarray, bias: jnp.ndarray | None, seqlen: jnp.ndarray, - seed: jnp.ndarray | None, attn_bias_type: NVTE_Bias_Type, - attn_mask_type: NVTE_Mask_Type, scaling_factor: float, - dropout_probability: float, is_training: bool): - """ - Wrapper for TE self fused attention fwd - Return BMM1 -> (PreScaleBias) -> Scale -> (PostScaleBias) -> Softmax -> (Dropout) -> BMM2 - """ - checker = _FusedAttnRNGStateChecker() - seed = checker.check_seed(seed, dropout_probability, is_training) - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: - assert bias is None - bias = jnp.zeros(0, dtype=qkv.dtype) - - return SelfFusedAttnFwdPrimitive.outer_primitive.bind(qkv, - bias, - seqlen, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) +register_primitive(FusedAttnFwdPrimitive) -class SelfFusedAttnBwdPrimitive(BasePrimitive): +class FusedAttnBwdPrimitive(BasePrimitive): """ - Self Fused Attention Backward Primitive + Fused Attention Backward Primitive """ - name = "te_self_fused_attn_backward" + name = "te_fused_attn_backward" multiple_results = True - impl_static_args = (7, 8, 9, 10, 11) + impl_static_args = (10, 11, 12, 13, 14, 15) inner_primitive = None outer_primitive = None @staticmethod - def abstract(qkv_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval, doutput_aval, - seqlen_or_cu_seqlen_aval, *, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): + def abstract(q_aval, k_aval, v_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval, + doutput_aval, q_cu_seqlen_aval, kv_cu_seqlen_aval, *, attn_bias_type, + attn_mask_type, qkv_layout, scaling_factor, dropout_probability, is_training): """ - Self fused attention bwd abstract + Fused attention bwd abstract """ - del softmax_aux_aval, rng_state_aval, seqlen_or_cu_seqlen_aval + del softmax_aux_aval, rng_state_aval, output_aval - assert qkv_aval.dtype == bias_aval.dtype == output_aval.dtype == doutput_aval.dtype - *input_batch_shape, max_seqlen, nqkv, 
attn_heads, head_dim = qkv_aval.shape - assert nqkv == 3 - qkv_dtype = dtypes.canonicalize_dtype(qkv_aval.dtype) + q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) + k_dtype = dtypes.canonicalize_dtype(k_aval.dtype) + v_dtype = dtypes.canonicalize_dtype(v_aval.dtype) bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) + doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype) + assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype + assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype + + batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = \ + FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: bias_batch = bias_heads = 0 @@ -2174,46 +2198,55 @@ def abstract(qkv_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval, *bias_batch_shape, bias_heads, _, _ = bias_aval.shape bias_batch = reduce(operator.mul, bias_batch_shape) - input_batch = reduce(operator.mul, input_batch_shape) + input_batch = reduce(operator.mul, batch_shape) wkspace_shape, wkspace_dtype = \ - transformer_engine_jax.get_self_fused_attn_bwd_workspace_sizes( - input_batch, bias_batch, max_seqlen, attn_heads, bias_heads, head_dim, - scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(qkv_aval.dtype), is_training - ) + transformer_engine_jax.get_fused_attn_bwd_workspace_sizes( + input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, + bias_heads, head_dim, scaling_factor, dropout_probability, attn_bias_type, + attn_mask_type, qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), is_training) - dqkv_aval = qkv_aval.update(shape=qkv_aval.shape, dtype=qkv_dtype) + dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) + dk_aval = k_aval.update(shape=k_aval.shape, dtype=k_dtype) + dv_aval = v_aval.update(shape=v_aval.shape, dtype=v_dtype) dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) - wkspace_aval = qkv_aval.update(shape=wkspace_shape, - dtype=te_dtype_to_jax_dtype(wkspace_dtype)) + wkspace_aval = q_aval.update(shape=wkspace_shape, + dtype=te_dtype_to_jax_dtype(wkspace_dtype)) - return dqkv_aval, dbias_aval, wkspace_aval + return dq_aval, dk_aval, dv_aval, dbias_aval, wkspace_aval @staticmethod def outer_abstract(*args, **kwargs): """ - Self fused attention bwd outer primitive abstract + Fused attention fwd outer primitive abstract """ - dqkv_aval, dbias_aval, _ = SelfFusedAttnBwdPrimitive.abstract(*args, **kwargs) - return dqkv_aval, dbias_aval + dq_aval, dk_aval, dv_aval, dbias_aval, _ = \ + FusedAttnBwdPrimitive.abstract(*args, **kwargs) + return dq_aval, dk_aval, dv_aval, dbias_aval @staticmethod - def lowering(ctx, qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen, *, - attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training): + def lowering(ctx, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, + kv_cu_seqlen, *, attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, + dropout_probability, is_training): """ - Self fused attention bwd lowering rules + Fused attention bwd lowering rules """ - operands = [qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen] + operands = [ + q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen + ] operand_shapes = map(lambda x: x.type.shape, operands) out_types = [ ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) for output in ctx.avals_out ] + args = 
CustomCallArgsWrapper(out_types, operands, operand_shapes) - qkv_aval, bias_aval, *_ = ctx.avals_in - *input_batch_shape, max_seqlen, _, attn_heads, head_dim = qkv_aval.shape - input_batch = reduce(operator.mul, input_batch_shape) + q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in + + batch_shape, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, head_dim = \ + FusedAttnHelper.parse_qkv_aval(q_aval, k_aval, v_aval, qkv_layout) + + input_batch = reduce(operator.mul, batch_shape) if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: bias_batch = bias_heads = 0 @@ -2224,780 +2257,245 @@ def lowering(ctx, qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen, wkspace_aval = ctx.avals_out[-1] opaque = transformer_engine_jax.pack_fused_attn_descriptor( - input_batch, bias_batch, max_seqlen, max_seqlen, - attn_heads, attn_heads, bias_heads, head_dim, wkspace_aval.size, - scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(qkv_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - is_training) + input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, + bias_heads, head_dim, wkspace_aval.size, scaling_factor, dropout_probability, + attn_bias_type, attn_mask_type, qkv_layout, jax_dtype_to_te_dtype(q_aval.dtype), + jax_dtype_to_te_dtype(wkspace_aval.dtype), is_training) - out = custom_caller(SelfFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) + out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) return out @staticmethod - def impl(qkv, bias, softmax_aux, rng_state, output, doutput, seqlen, attn_bias_type, - attn_mask_type, scaling_factor, dropout_probability, is_training): - assert SelfFusedAttnBwdPrimitive.inner_primitive is not None + def impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen, + attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability, + is_training): + assert FusedAttnBwdPrimitive.inner_primitive is not None - cu_seqlen = generate_cu_seqlen(seqlen) + q_cu_seqlen = generate_cu_seqlen(q_seqlen) + kv_cu_seqlen = generate_cu_seqlen(kv_seqlen) - dqkv, dbias, _ = SelfFusedAttnBwdPrimitive.inner_primitive.bind( - qkv, + dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind( + q, + k, + v, bias, softmax_aux, rng_state, output, doutput, - cu_seqlen, + q_cu_seqlen, + kv_cu_seqlen, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, scaling_factor=scaling_factor, dropout_probability=dropout_probability, is_training=is_training) - return dqkv, dbias + return dq, dk, dv, dbias @staticmethod - def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): + def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, qkv_layout, + scaling_factor, dropout_probability, is_training): _check_valid_batch_dims(batch_dims) - assert SelfFusedAttnBwdPrimitive.outer_primitive is not None - qkv_bdim, *_ = batch_dims + assert FusedAttnBwdPrimitive.outer_primitive is not None + q_bdim, k_bdim, v_bdim, *_ = batch_dims - out_bdims = qkv_bdim, qkv_bdim - return SelfFusedAttnBwdPrimitive.outer_primitive.bind( - *batched_args, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training), out_bdims + out_bdims = q_bdim, k_bdim, v_bdim, q_bdim + return 
FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training), out_bdims @staticmethod - def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor, + def infer_sharding_from_operands(attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability, is_training, mesh, arg_infos, result_infos): - del attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, - del is_training, result_infos - x_spec = get_padded_spec(arg_infos[0]) - bias_spec = get_padded_spec(arg_infos[1]) - dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + del attn_bias_type, attn_mask_type, qkv_layout, scaling_factor + del dropout_probability, is_training, result_infos + q_spec = get_padded_spec(arg_infos[0]) + k_spec = get_padded_spec(arg_infos[1]) + v_spec = get_padded_spec(arg_infos[2]) + bias_spec = get_padded_spec(arg_infos[3]) + dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) + dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) - return (dx_sharding, dbias_sharding) + return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) @staticmethod - def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, - mesh, arg_infos, result_infos): + def partition(attn_bias_type, attn_mask_type, qkv_layout, scaling_factor, dropout_probability, + is_training, mesh, arg_infos, result_infos): del result_infos - x_spec = get_padded_spec(arg_infos[0]) - bias_spec = get_padded_spec(arg_infos[1]) - dx_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + q_spec = get_padded_spec(arg_infos[0]) + k_spec = get_padded_spec(arg_infos[1]) + v_spec = get_padded_spec(arg_infos[2]) + bias_spec = get_padded_spec(arg_infos[3]) + dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) + dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) + dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - out_shardings = (dx_sharding, dbias_sharding) + out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) - def sharded_impl(qkv, bias, softmax_aux, rng_state, output, doutput, cu_seqlen): - local_dx, local_dbias = SelfFusedAttnBwdPrimitive.impl( - qkv, + def sharded_impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, + kv_cu_seqlen): + local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl( + q, + k, + v, bias, softmax_aux, rng_state, output, doutput, - cu_seqlen, + q_cu_seqlen, + kv_cu_seqlen, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, scaling_factor=scaling_factor, dropout_probability=dropout_probability, is_training=is_training) global_dbias = local_dbias if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) - return local_dx, global_dbias + return local_dq, local_dk, local_dv, global_dbias return mesh, sharded_impl, out_shardings, arg_shardings -register_primitive(SelfFusedAttnBwdPrimitive) +register_primitive(FusedAttnBwdPrimitive) -def self_fused_attn_bwd(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.ndarray, - rng_state: jnp.ndarray, 
output: jnp.ndarray, doutput: jnp.ndarray, - seqlen: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, - attn_mask_type: NVTE_Mask_Type, scaling_factor: float, - dropout_probability: float, is_training: bool): +def fused_attn_fwd_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray, seqlen: jnp.ndarray, + seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, scaling_factor: float, + dropout_probability: float, is_training: bool): """ - Wrapper for TE self fused attention bwd - Return the gradients of self fused attention with packed qkv input + Wrapper for TE self fused attention fwd + Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 """ + checker = _FusedAttnRNGStateChecker() + seed = checker.check_seed(seed, dropout_probability, is_training) + if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: assert bias is None bias = jnp.zeros(0, dtype=qkv.dtype) - return SelfFusedAttnBwdPrimitive.outer_primitive.bind(qkv, - bias, - softmax_aux, - rng_state, - output, - doutput, - seqlen, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + _not_used = jnp.zeros(0, qkv.dtype) + return FusedAttnFwdPrimitive.outer_primitive.bind(qkv, + _not_used, + _not_used, + bias, + seqlen, + seqlen, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=NVTE_QKV_Layout.NVTE_BS3HD, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) -class CrossFusedAttnFwdPrimitive(BasePrimitive): +def fused_attn_bwd_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray, softmax_aux: jnp.ndarray, + rng_state: jnp.ndarray, output: jnp.ndarray, doutput: jnp.ndarray, + seqlen: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, + attn_mask_type: NVTE_Mask_Type, scaling_factor: float, + dropout_probability: float, is_training: bool): + """ + Wrapper for TE self fused attention bwd + Return the gradients of self fused attention with packed qkv input """ - Cross Fused Attention Forward Primitive + if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + assert bias is None + bias = jnp.zeros(0, dtype=qkv.dtype) + dummy_input = jnp.zeros(0, dtype=qkv.dtype) + dqkv, *_, dbias = FusedAttnBwdPrimitive.outer_primitive.bind( + qkv, + dummy_input, + dummy_input, + bias, + softmax_aux, + rng_state, + output, + doutput, + seqlen, + seqlen, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=NVTE_QKV_Layout.NVTE_BS3HD, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) + return dqkv, dbias + + +def fused_attn_fwd_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, + q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray, seed: jnp.ndarray, + attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, + scaling_factor: float, dropout_probability: float, is_training: bool): + """ + Wrapper for TE fused attention fwd with kvpacked inputs + Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 """ - name = "te_cross_fused_attn_forward" - multiple_results = True - impl_static_args = (6, 7, 8, 9, 10) - inner_primitive = None - outer_primitive = None + checker = _FusedAttnRNGStateChecker() + seed = checker.check_seed(seed, dropout_probability, is_training) - @staticmethod - def abstract(q_aval, kv_aval, bias_aval, q_seqlen_or_cu_seqlen_aval, - kv_seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, attn_mask_type, - 
scaling_factor, dropout_probability, is_training): - """ - Cross fused attention fwd abstract - """ - q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) - kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype) - bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) - assert q_dtype == kv_dtype == bias_dtype - assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype + if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + assert bias is None + bias = jnp.zeros(0, dtype=q.dtype) - *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape - *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = kv_aval.shape - assert q_batch_shape == kv_batch_shape - assert q_head_dim == kv_head_dim - assert nkv == 2 - out_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) + return FusedAttnFwdPrimitive.outer_primitive.bind(q, + kv, + jnp.zeros(0, q.dtype), + bias, + q_seqlen, + kv_seqlen, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) - # backend determines the softmax buffer shape/dtype - backend = FusedAttnHelper(q_dtype, kv_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD, - attn_bias_type, attn_mask_type, dropout_probability, attn_heads, - num_gqa_groups, q_max_seqlen, kv_max_seqlen, - q_head_dim).get_fused_attn_backend() - if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: - softmax_shape = (*q_batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen) - softmax_dtype = q_dtype - elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: - softmax_shape = (*q_batch_shape, attn_heads, q_max_seqlen, 1) - softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) - else: - raise ValueError(f'Unsupported {backend=}') - softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype) - - # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with - # 32-bit unsigned int to get the buffer size we need in the C++ kernel - checker = _FusedAttnRNGStateChecker() - seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype) - assert seed_dtype == checker.rng_state_dtype - rng_state_shape = (seed_aval.shape[0], checker.rng_state_size) - rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype) - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: - bias_batch = bias_heads = 0 - else: - *bias_batch_shape, bias_heads, _, _ = bias_aval.shape - bias_batch = reduce(operator.mul, bias_batch_shape) - - # do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to - # prepare for the active fused-attn backend - input_batch = reduce(operator.mul, q_batch_shape) - wkspace_info = transformer_engine_jax.get_cross_fused_attn_fwd_workspace_sizes( - input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, - attn_heads, num_gqa_groups, bias_heads, q_head_dim, - scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(q_aval.dtype), is_training) - wkspace_aval = q_aval.update(shape=wkspace_info[0], - dtype=te_dtype_to_jax_dtype(wkspace_info[1])) - - return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval - - @staticmethod - def outer_abstract(*args, **kwargs): - """ - Cross fused attention fwd outer primitive abstract - """ - out_aval, softmax_aux_aval, rng_state_aval, _ = \ - CrossFusedAttnFwdPrimitive.abstract(*args, **kwargs) - return out_aval, softmax_aux_aval, rng_state_aval - - 
@staticmethod - def lowering(ctx, q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type, - attn_mask_type, scaling_factor, dropout_probability, is_training): - """ - Cross fused attention fwd lowering rules - """ - operands = [q, kv, bias, q_cu_seqlen, kv_cu_seqlen, seed] - operand_shapes = map(lambda x: x.type.shape, operands) - out_types = [ - ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) - for output in ctx.avals_out - ] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - q_aval, kv_aval, bias_aval, *_ = ctx.avals_in - *input_batch_shape, q_max_seqlen, attn_heads, head_dim = q_aval.shape - *_, kv_max_seqlen, _, num_gqa_groups, _ = kv_aval.shape - input_batch = reduce(operator.mul, input_batch_shape) - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: - bias_batch = bias_heads = 0 - else: - *bias_batch_shape, bias_heads, _, _ = bias_aval.shape - bias_batch = reduce(operator.mul, bias_batch_shape) - - wkspace_aval = ctx.avals_out[-1] - - opaque = transformer_engine_jax.pack_fused_attn_descriptor( - input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, - attn_heads, num_gqa_groups, bias_heads, head_dim, - wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - is_training) - - out = custom_caller(CrossFusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) - - return out - - @staticmethod - def impl(q, kv, bias, q_seqlen, kv_seqlen, seed, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): - assert CrossFusedAttnFwdPrimitive.inner_primitive is not None - - q_cu_seqlen = generate_cu_seqlen(q_seqlen) - kv_cu_seqlen = generate_cu_seqlen(kv_seqlen) - - output, softmax_aux, rng_state, _ = CrossFusedAttnFwdPrimitive.inner_primitive.bind( - q, - kv, - bias, - q_cu_seqlen, - kv_cu_seqlen, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - return output, softmax_aux, rng_state - - @staticmethod - def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): - _check_valid_batch_dims(batch_dims) - assert CrossFusedAttnFwdPrimitive.outer_primitive is not None - q_bdim, *_, seed_bdim = batch_dims - - out_bdims = q_bdim, q_bdim, seed_bdim - return CrossFusedAttnFwdPrimitive.outer_primitive.bind( - *batched_args, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training), out_bdims - - @staticmethod - def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training, mesh, arg_infos, - result_infos): - del attn_bias_type, attn_mask_type, scaling_factor - del dropout_probability, is_training, result_infos - q_spec = get_padded_spec(arg_infos[0]) # (...batch, q_seqlen, head, hidden) - kv_spec = get_padded_spec(arg_infos[1]) # (...batch, kv_seqlen, 2, head, hidden) - out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], kv_spec[-4])) - rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None)) - return (out_sharding, softmax_aux_sharding, rng_state_sharding) - - @staticmethod - def partition(attn_bias_type, 
attn_mask_type, scaling_factor, dropout_probability, is_training, - mesh, arg_infos, result_infos): - del result_infos - q_spec = get_padded_spec(arg_infos[0]) # (...batch, q_seqlen, head, hidden) - kv_spec = get_padded_spec(arg_infos[1]) # (...batch, kv_seqlen, 2, head, hidden) - out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], kv_spec[-4])) - rng_state_sharding = seed_sharding = NamedSharding(mesh, - PartitionSpec(get_all_mesh_axes(), None)) - arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) - out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) - impl = partial(CrossFusedAttnFwdPrimitive.impl, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - return mesh, impl, out_shardings, arg_shardings - - -register_primitive(CrossFusedAttnFwdPrimitive) - - -def cross_fused_attn_fwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, q_seqlen: jnp.ndarray, - kv_seqlen: jnp.ndarray, seed: jnp.ndarray, attn_bias_type: NVTE_Bias_Type, - attn_mask_type: NVTE_Mask_Type, scaling_factor: float, - dropout_probability: float, is_training: bool): - """ - Wrapper for TE cross fused attention fwd - Return BMM1 -> (PreScaleBias) -> Scale -> (PostScaleBias) -> Softmax -> (Dropout) -> BMM2 - """ - checker = _FusedAttnRNGStateChecker() - seed = checker.check_seed(seed, dropout_probability, is_training) - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: - assert bias is None - bias = jnp.zeros(0, dtype=q.dtype) - - return CrossFusedAttnFwdPrimitive.outer_primitive.bind(q, - kv, - bias, - q_seqlen, - kv_seqlen, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - - -class CrossFusedAttnBwdPrimitive(BasePrimitive): - """ - Cross Fused Attention Backward Primitive - """ - name = "te_cross_fused_attn_backward" - multiple_results = True - impl_static_args = (9, 10, 11, 12, 13) - inner_primitive = None - outer_primitive = None - - @staticmethod - def abstract(q_aval, kv_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval, - doutput_aval, q_cu_seqlen_aval, kv_cu_seqlen_aval, *, attn_bias_type, - attn_mask_type, scaling_factor, dropout_probability, is_training): - """ - Cross fused attention bwd abstract - """ - del softmax_aux_aval, rng_state_aval, output_aval - - q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) - kv_dtype = dtypes.canonicalize_dtype(kv_aval.dtype) - bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) - doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype) - assert q_dtype == kv_dtype == bias_dtype == doutput_dtype - assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype - - *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape - *kv_batch_shape, kv_max_seqlen, nkv, num_gqa_groups, kv_head_dim = kv_aval.shape - assert q_batch_shape == kv_batch_shape - assert q_head_dim == kv_head_dim - assert nkv == 2 - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: - bias_batch = bias_heads = 0 - else: - *bias_batch_shape, bias_heads, _, _ = bias_aval.shape - bias_batch = reduce(operator.mul, bias_batch_shape) - - input_batch = reduce(operator.mul, q_batch_shape) - wkspace_shape, wkspace_dtype = \ - transformer_engine_jax.get_cross_fused_attn_bwd_workspace_sizes( - 
input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, - attn_heads, num_gqa_groups, bias_heads, q_head_dim, - scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(q_aval.dtype), is_training - ) - - dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) - dkv_aval = kv_aval.update(shape=kv_aval.shape, dtype=kv_dtype) - dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) - wkspace_aval = q_aval.update(shape=wkspace_shape, - dtype=te_dtype_to_jax_dtype(wkspace_dtype)) - - return dq_aval, dkv_aval, dbias_aval, wkspace_aval - - @staticmethod - def outer_abstract(*args, **kwargs): - """ - Cross fused attention fwd outer primitive abstract - """ - dq_aval, dkv_aval, dbias_aval, _ = \ - CrossFusedAttnBwdPrimitive.abstract(*args, **kwargs) - return dq_aval, dkv_aval, dbias_aval - - @staticmethod - def lowering(ctx, q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, - kv_cu_seqlen, *, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): - """ - Cross fused attention bwd lowering rules - """ - operands = [q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen] - operand_shapes = map(lambda x: x.type.shape, operands) - out_types = [ - ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) - for output in ctx.avals_out - ] - - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - q_aval, kv_aval, bias_aval, *_ = ctx.avals_in - *input_batch_shape, q_max_seqlen, attn_heads, head_dim = q_aval.shape - *_, kv_max_seqlen, _, num_gqa_groups, _ = kv_aval.shape - input_batch = reduce(operator.mul, input_batch_shape) - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: - bias_batch = bias_heads = 0 - else: - *bias_batch_shape, bias_heads, _, _ = bias_aval.shape - bias_batch = reduce(operator.mul, bias_batch_shape) - - wkspace_aval = ctx.avals_out[-1] - - opaque = transformer_engine_jax.pack_fused_attn_descriptor( - input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, - attn_heads, num_gqa_groups, bias_heads, head_dim, - wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - is_training) - - out = custom_caller(CrossFusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) - - return out - - @staticmethod - def impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen, - attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training): - assert CrossFusedAttnBwdPrimitive.inner_primitive is not None - - q_cu_seqlen = generate_cu_seqlen(q_seqlen) - kv_cu_seqlen = generate_cu_seqlen(kv_seqlen) - - dq, dkv, dbias, _ = CrossFusedAttnBwdPrimitive.inner_primitive.bind( - q, - kv, - bias, - softmax_aux, - rng_state, - output, - doutput, - q_cu_seqlen, - kv_cu_seqlen, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - return dq, dkv, dbias - - @staticmethod - def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): - _check_valid_batch_dims(batch_dims) - assert CrossFusedAttnBwdPrimitive.outer_primitive is not None - q_bdim, kv_bdim, *_ = batch_dims - - out_bdims = q_bdim, kv_bdim, q_bdim - return CrossFusedAttnBwdPrimitive.outer_primitive.bind( - *batched_args, - attn_bias_type=attn_bias_type, - 
attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training), out_bdims - - @staticmethod - def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training, mesh, arg_infos, - result_infos): - del attn_bias_type, attn_mask_type, scaling_factor - del dropout_probability, is_training, result_infos - q_spec = get_padded_spec(arg_infos[0]) - kv_spec = get_padded_spec(arg_infos[1]) - bias_spec = get_padded_spec(arg_infos[2]) - dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - dkv_sharding = NamedSharding(mesh, PartitionSpec(*kv_spec)) - dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) - return (dq_sharding, dkv_sharding, dbias_sharding) - - @staticmethod - def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, - mesh, arg_infos, result_infos): - del result_infos - q_spec = get_padded_spec(arg_infos[0]) - kv_spec = get_padded_spec(arg_infos[1]) - bias_spec = get_padded_spec(arg_infos[2]) - dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - dkv_sharding = NamedSharding(mesh, PartitionSpec(*kv_spec)) - dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - out_shardings = (dq_sharding, dkv_sharding, dbias_sharding) - - def sharded_impl(q, kv, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, - kv_cu_seqlen): - local_dq, local_dkv, local_dbias = CrossFusedAttnBwdPrimitive.impl( - q, - kv, - bias, - softmax_aux, - rng_state, - output, - doutput, - q_cu_seqlen, - kv_cu_seqlen, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - global_dbias = local_dbias - if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: - global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) - return local_dq, local_dkv, global_dbias - - return mesh, sharded_impl, out_shardings, arg_shardings - - -register_primitive(CrossFusedAttnBwdPrimitive) - - -def cross_fused_attn_bwd(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, - softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray, - doutput: jnp.ndarray, q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray, - attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, - scaling_factor: float, dropout_probability: float, is_training: bool): - """ - Wrapper for TE cross fused attention bwd - Return the gradients of cross fused attention with packed kv input - """ - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: - assert bias is None - bias = jnp.zeros(0, dtype=q.dtype) - - return CrossFusedAttnBwdPrimitive.outer_primitive.bind(q, - kv, - bias, - softmax_aux, - rng_state, - output, - doutput, - q_seqlen, - kv_seqlen, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - - -class FusedAttnFwdPrimitive(BasePrimitive): - """ - Fused Attention Forward Primitive - Query, key, value are seperated tensors - """ - name = "te_fused_attn_forward" - multiple_results = True - impl_static_args = (7, 8, 9, 10, 11) - inner_primitive = None - outer_primitive = None - - @staticmethod - def abstract(q_aval, k_aval, v_aval, bias_aval, q_seqlen_or_cu_seqlen_aval, - kv_seqlen_or_cu_seqlen_aval, seed_aval, *, attn_bias_type, attn_mask_type, - scaling_factor, 
dropout_probability, is_training): - """ - Fused attention fwd abstract - """ - q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) - k_dtype = dtypes.canonicalize_dtype(k_aval.dtype) - v_dtype = dtypes.canonicalize_dtype(v_aval.dtype) - bias_dtype = dtypes.canonicalize_dtype(bias_aval.dtype) - assert q_dtype == k_dtype == v_dtype == bias_dtype - assert q_seqlen_or_cu_seqlen_aval.dtype == kv_seqlen_or_cu_seqlen_aval.dtype - - *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape - *kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape - assert q_batch_shape == kv_batch_shape - assert q_head_dim == kv_head_dim - assert k_aval.shape == v_aval.shape - out_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) - - # backend determines the softmax buffer shape/dtype - backend = FusedAttnHelper(q_dtype, k_dtype, NVTE_QKV_Layout.NVTE_BSHD_BS2HD, attn_bias_type, - attn_mask_type, dropout_probability, attn_heads, num_gqa_groups, - q_max_seqlen, kv_max_seqlen, q_head_dim).get_fused_attn_backend() - - if backend == NVTE_Fused_Attn_Backend.NVTE_F16_max512_seqlen: - softmax_shape = (*q_batch_shape, attn_heads, q_max_seqlen, kv_max_seqlen) - softmax_dtype = q_dtype - elif backend == NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: - softmax_shape = (*q_batch_shape, attn_heads, q_max_seqlen, 1) - softmax_dtype = dtypes.canonicalize_dtype(jnp.float32) - else: - raise ValueError(f'Unsupported {backend=}') - softmax_aux_aval = q_aval.update(shape=softmax_shape, dtype=softmax_dtype) - - # JAX does not enable 64-bit int by default so we get XLA to allocate x8 memory with - # 32-bit unsigned int to get the buffer size we need in the C++ kernel - checker = _FusedAttnRNGStateChecker() - seed_dtype = dtypes.canonicalize_dtype(seed_aval.dtype) - assert seed_dtype == checker.rng_state_dtype - rng_state_shape = (seed_aval.shape[0], checker.rng_state_size) - rng_state_aval = seed_aval.update(shape=rng_state_shape, dtype=checker.rng_state_dtype) - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: - bias_batch = bias_heads = 0 - else: - *bias_batch_shape, bias_heads, _, _ = bias_aval.shape - bias_batch = reduce(operator.mul, bias_batch_shape) - - # do a dummy kernel call here to get workspace buffer shapes/dtypes that XLA needs to - # prepare for the active fused-attn backend - input_batch = reduce(operator.mul, q_batch_shape) - wkspace_info = transformer_engine_jax.get_fused_attn_fwd_workspace_sizes( - input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, - attn_heads, num_gqa_groups, bias_heads, q_head_dim, - scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(q_aval.dtype), is_training) - wkspace_aval = q_aval.update(shape=wkspace_info[0], - dtype=te_dtype_to_jax_dtype(wkspace_info[1])) - - return out_aval, softmax_aux_aval, rng_state_aval, wkspace_aval - - @staticmethod - def outer_abstract(*args, **kwargs): - """ - Fused attention fwd outer primitive abstract - """ - out_aval, softmax_aux_aval, rng_state_aval, _ = \ - FusedAttnFwdPrimitive.abstract(*args, **kwargs) - return out_aval, softmax_aux_aval, rng_state_aval - - @staticmethod - def lowering(ctx, q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed, *, attn_bias_type, - attn_mask_type, scaling_factor, dropout_probability, is_training): - """ - Fused attention fwd lowering rules - """ - operands = [q, k, v, bias, q_cu_seqlen, kv_cu_seqlen, seed] - operand_shapes = map(lambda x: x.type.shape, operands) - out_types = [ - ir.RankedTensorType.get(output.shape, 
mlir.dtype_to_ir_type(output.dtype)) - for output in ctx.avals_out - ] - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in - *batch_shape, q_max_seqlen, attn_heads, head_dim = q_aval.shape - *_, kv_max_seqlen, num_gqa_groups, _ = k_aval.shape - assert k_aval.shape == v_aval.shape - input_batch = reduce(operator.mul, batch_shape) - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: - bias_batch = bias_heads = 0 - else: - *bias_batch_shape, bias_heads, _, _ = bias_aval.shape - bias_batch = reduce(operator.mul, bias_batch_shape) - - wkspace_aval = ctx.avals_out[-1] - - opaque = transformer_engine_jax.pack_fused_attn_descriptor( - input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, - attn_heads, num_gqa_groups, bias_heads, head_dim, - wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - is_training) - - out = custom_caller(FusedAttnFwdPrimitive.name, args, opaque, has_side_effect=False) - - return out - - @staticmethod - def impl(q, k, v, bias, q_seqlen, kv_seqlen, seed, attn_bias_type, attn_mask_type, - scaling_factor, dropout_probability, is_training): - assert FusedAttnFwdPrimitive.inner_primitive is not None - - q_cu_seqlen = generate_cu_seqlen(q_seqlen) - kv_cu_seqlen = generate_cu_seqlen(kv_seqlen) - - output, softmax_aux, rng_state, _ = FusedAttnFwdPrimitive.inner_primitive.bind( - q, - k, - v, - bias, - q_cu_seqlen, - kv_cu_seqlen, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - return output, softmax_aux, rng_state - - @staticmethod - def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): - _check_valid_batch_dims(batch_dims) - assert FusedAttnFwdPrimitive.outer_primitive is not None - q_bdim, *_, seed_bdim = batch_dims - - out_bdims = q_bdim, q_bdim, seed_bdim - return FusedAttnFwdPrimitive.outer_primitive.bind(*batched_args, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training), out_bdims - - @staticmethod - def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training, mesh, arg_infos, - result_infos): - del attn_bias_type, attn_mask_type, scaling_factor - del dropout_probability, is_training, result_infos - q_spec = get_padded_spec(arg_infos[0]) # (...batch, q_seqlen, head, hidden) - k_spec = get_padded_spec(arg_infos[1]) # (...batch, kv_seqlen, head, hidden) - out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - softmax_aux_sharding = NamedSharding( - mesh, PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3])) - rng_state_sharding = NamedSharding(mesh, PartitionSpec(get_all_mesh_axes(), None)) - return (out_sharding, softmax_aux_sharding, rng_state_sharding) - - @staticmethod - def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, - mesh, arg_infos, result_infos): - del result_infos - q_spec = get_padded_spec(arg_infos[0]) # (...batch, q_seqlen, head, hidden) - k_spec = get_padded_spec(arg_infos[1]) # (...batch, kv_seqlen, head, hidden) - out_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - softmax_aux_sharding = NamedSharding( - mesh, 
PartitionSpec(*q_spec[:-3], q_spec[-2], q_spec[-3], k_spec[-3])) - rng_state_sharding = seed_sharding = NamedSharding(mesh, - PartitionSpec(get_all_mesh_axes(), None)) - arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) - out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) - impl = partial(FusedAttnFwdPrimitive.impl, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - return mesh, impl, out_shardings, arg_shardings - - -register_primitive(FusedAttnFwdPrimitive) +def fused_attn_bwd_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, + softmax_aux: jnp.ndarray, rng_state: jnp.ndarray, output: jnp.ndarray, + doutput: jnp.ndarray, q_seqlen: jnp.ndarray, kv_seqlen: jnp.ndarray, + attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, + scaling_factor: float, dropout_probability: float, is_training: bool): + """ + Wrapper for TE fused attention bwd with kvpacked inputs + Return the gradients of fused attention with packed kv input + """ + if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: + assert bias is None + bias = jnp.zeros(0, dtype=q.dtype) + dummy_input = jnp.zeros(0, q.dtype) + dq, dkv, _, dbias = FusedAttnBwdPrimitive.outer_primitive.bind( + q, + kv, + dummy_input, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BS2HD, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) + return dq, dkv, dbias def fused_attn_fwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, @@ -3006,7 +2504,7 @@ def fused_attn_fwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.nda scaling_factor: float, dropout_probability: float, is_training: bool): """ Wrapper for TE fused attention fwd, where query, key, value are seperated tensors - Return BMM1 -> (PreScaleBias) -> Scale -> (PostScaleBias) -> Softmax -> (Dropout) -> BMM2 + Return BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 """ checker = _FusedAttnRNGStateChecker() seed = checker.check_seed(seed, dropout_probability, is_training) @@ -3015,228 +2513,20 @@ def fused_attn_fwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.nda assert bias is None bias = jnp.zeros(0, dtype=q.dtype) - return FusedAttnFwdPrimitive.outer_primitive.bind(q, - k, - v, - bias, - q_seqlen, - kv_seqlen, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - - -class FusedAttnBwdPrimitive(BasePrimitive): - """ - Fused Attention Backward Primitive - """ - name = "te_fused_attn_backward" - multiple_results = True - impl_static_args = (10, 11, 12, 13, 14) - inner_primitive = None - outer_primitive = None - - @staticmethod - def abstract(q_aval, k_aval, v_aval, bias_aval, softmax_aux_aval, rng_state_aval, output_aval, - doutput_aval, q_cu_seqlen_aval, kv_cu_seqlen_aval, *, attn_bias_type, - attn_mask_type, scaling_factor, dropout_probability, is_training): - """ - Fused attention bwd abstract - """ - del softmax_aux_aval, rng_state_aval, output_aval - - q_dtype = dtypes.canonicalize_dtype(q_aval.dtype) - k_dtype = dtypes.canonicalize_dtype(k_aval.dtype) - v_dtype = dtypes.canonicalize_dtype(v_aval.dtype) - bias_dtype = 
dtypes.canonicalize_dtype(bias_aval.dtype) - doutput_dtype = dtypes.canonicalize_dtype(doutput_aval.dtype) - assert q_dtype == k_dtype == v_dtype == bias_dtype == doutput_dtype - assert q_cu_seqlen_aval.dtype == kv_cu_seqlen_aval.dtype - - *q_batch_shape, q_max_seqlen, attn_heads, q_head_dim = q_aval.shape - *kv_batch_shape, kv_max_seqlen, num_gqa_groups, kv_head_dim = k_aval.shape - assert q_batch_shape == kv_batch_shape - assert q_head_dim == kv_head_dim - assert k_aval.shape == v_aval.shape - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: - bias_batch = bias_heads = 0 - else: - *bias_batch_shape, bias_heads, _, _ = bias_aval.shape - bias_batch = reduce(operator.mul, bias_batch_shape) - - input_batch = reduce(operator.mul, q_batch_shape) - wkspace_shape, wkspace_dtype = \ - transformer_engine_jax.get_fused_attn_bwd_workspace_sizes( - input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, - attn_heads, num_gqa_groups, bias_heads, q_head_dim, - scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(q_aval.dtype), is_training - ) - - dq_aval = q_aval.update(shape=q_aval.shape, dtype=q_dtype) - dk_aval = k_aval.update(shape=k_aval.shape, dtype=k_dtype) - dv_aval = v_aval.update(shape=v_aval.shape, dtype=v_dtype) - dbias_aval = bias_aval.update(shape=bias_aval.shape, dtype=bias_dtype) - wkspace_aval = q_aval.update(shape=wkspace_shape, - dtype=te_dtype_to_jax_dtype(wkspace_dtype)) - - return dq_aval, dk_aval, dv_aval, dbias_aval, wkspace_aval - - @staticmethod - def outer_abstract(*args, **kwargs): - """ - Fused attention fwd outer primitive abstract - """ - dq_aval, dk_aval, dv_aval, dbias_aval, _ = \ - FusedAttnBwdPrimitive.abstract(*args, **kwargs) - return dq_aval, dk_aval, dv_aval, dbias_aval - - @staticmethod - def lowering(ctx, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, - kv_cu_seqlen, *, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): - """ - Fused attention bwd lowering rules - """ - operands = [ - q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, kv_cu_seqlen - ] - operand_shapes = map(lambda x: x.type.shape, operands) - out_types = [ - ir.RankedTensorType.get(output.shape, mlir.dtype_to_ir_type(output.dtype)) - for output in ctx.avals_out - ] - - args = CustomCallArgsWrapper(out_types, operands, operand_shapes) - - q_aval, k_aval, v_aval, bias_aval, *_ = ctx.avals_in - *batch_shape, q_max_seqlen, attn_heads, head_dim = q_aval.shape - *_, kv_max_seqlen, num_gqa_groups, _ = k_aval.shape - assert k_aval.shape == v_aval.shape - input_batch = reduce(operator.mul, batch_shape) - - if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: - bias_batch = bias_heads = 0 - else: - *bias_batch_shape, bias_heads, _, _ = bias_aval.shape - bias_batch = reduce(operator.mul, bias_batch_shape) - - wkspace_aval = ctx.avals_out[-1] - - opaque = transformer_engine_jax.pack_fused_attn_descriptor( - input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, - attn_heads, num_gqa_groups, bias_heads, head_dim, - wkspace_aval.size, scaling_factor, dropout_probability, attn_bias_type, attn_mask_type, - jax_dtype_to_te_dtype(q_aval.dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype), - is_training) - - out = custom_caller(FusedAttnBwdPrimitive.name, args, opaque, has_side_effect=False) - - return out - - @staticmethod - def impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen, - attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training): - assert 
FusedAttnBwdPrimitive.inner_primitive is not None - - q_cu_seqlen = generate_cu_seqlen(q_seqlen) - kv_cu_seqlen = generate_cu_seqlen(kv_seqlen) - - dq, dk, dv, dbias, _ = FusedAttnBwdPrimitive.inner_primitive.bind( - q, - k, - v, - bias, - softmax_aux, - rng_state, - output, - doutput, - q_cu_seqlen, - kv_cu_seqlen, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - return dq, dk, dv, dbias - - @staticmethod - def batcher(batched_args, batch_dims, *, attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training): - _check_valid_batch_dims(batch_dims) - assert FusedAttnBwdPrimitive.outer_primitive is not None - q_bdim, k_bdim, v_bdim, *_ = batch_dims - - out_bdims = q_bdim, k_bdim, v_bdim, q_bdim - return FusedAttnBwdPrimitive.outer_primitive.bind(*batched_args, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training), out_bdims - - @staticmethod - def infer_sharding_from_operands(attn_bias_type, attn_mask_type, scaling_factor, - dropout_probability, is_training, mesh, arg_infos, - result_infos): - del attn_bias_type, attn_mask_type, scaling_factor - del dropout_probability, is_training, result_infos - q_spec = get_padded_spec(arg_infos[0]) - k_spec = get_padded_spec(arg_infos[1]) - v_spec = get_padded_spec(arg_infos[2]) - bias_spec = get_padded_spec(arg_infos[3]) - dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) - dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) - dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) - return (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) - - @staticmethod - def partition(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training, - mesh, arg_infos, result_infos): - del result_infos - q_spec = get_padded_spec(arg_infos[0]) - k_spec = get_padded_spec(arg_infos[1]) - v_spec = get_padded_spec(arg_infos[2]) - bias_spec = get_padded_spec(arg_infos[3]) - dq_sharding = NamedSharding(mesh, PartitionSpec(*q_spec)) - dk_sharding = NamedSharding(mesh, PartitionSpec(*k_spec)) - dv_sharding = NamedSharding(mesh, PartitionSpec(*v_spec)) - dbias_sharding = NamedSharding(mesh, PartitionSpec(*bias_spec)) - arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) - out_shardings = (dq_sharding, dk_sharding, dv_sharding, dbias_sharding) - - def sharded_impl(q, k, v, bias, softmax_aux, rng_state, output, doutput, q_cu_seqlen, - kv_cu_seqlen): - local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl( - q, - k, - v, - bias, - softmax_aux, - rng_state, - output, - doutput, - q_cu_seqlen, - kv_cu_seqlen, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) - global_dbias = local_dbias - if attn_bias_type is not NVTE_Bias_Type.NVTE_NO_BIAS: - global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) - return local_dq, local_dk, local_dv, global_dbias - - return mesh, sharded_impl, out_shardings, arg_shardings - - -register_primitive(FusedAttnBwdPrimitive) + return FusedAttnFwdPrimitive.outer_primitive.bind( + q, + k, + v, + bias, + q_seqlen, + kv_seqlen, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD, + 
scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) def fused_attn_bwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, @@ -3251,22 +2541,23 @@ def fused_attn_bwd(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.nda if attn_bias_type == NVTE_Bias_Type.NVTE_NO_BIAS: assert bias is None bias = jnp.zeros(0, dtype=q.dtype) - - return FusedAttnBwdPrimitive.outer_primitive.bind(q, - k, - v, - bias, - softmax_aux, - rng_state, - output, - doutput, - q_seqlen, - kv_seqlen, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + return FusedAttnBwdPrimitive.outer_primitive.bind( + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=NVTE_QKV_Layout.NVTE_BSHD_BSHD_BSHD, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) class GeluPrimitive(BasePrimitive): diff --git a/transformer_engine/jax/csrc/extensions.cpp b/transformer_engine/jax/csrc/extensions.cpp index 5faec6fd10..5e4ab4f205 100644 --- a/transformer_engine/jax/csrc/extensions.cpp +++ b/transformer_engine/jax/csrc/extensions.cpp @@ -49,10 +49,6 @@ pybind11::dict Registrations() { EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxForward); dict["te_scaled_upper_triang_masked_softmax_backward"] = EncapsulateFunction(ScaledUpperTriangMaskedSoftmaxBackward); - dict["te_self_fused_attn_forward"] = EncapsulateFunction(SelfFusedAttnForward); - dict["te_self_fused_attn_backward"] = EncapsulateFunction(SelfFusedAttnBackward); - dict["te_cross_fused_attn_forward"] = EncapsulateFunction(CrossFusedAttnForward); - dict["te_cross_fused_attn_backward"] = EncapsulateFunction(CrossFusedAttnBackward); dict["te_fused_attn_forward"] = EncapsulateFunction(FusedAttnForward); dict["te_fused_attn_backward"] = EncapsulateFunction(FusedAttnBackward); return dict; @@ -72,10 +68,6 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_dgelu_dbias_ct_workspace_sizes", &GetDGeluDBiasCastTransposeWorkspaceSizes); m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes); m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes); - m.def("get_self_fused_attn_fwd_workspace_sizes", &GetSelfFusedAttnForwardWorkspaceSizes); - m.def("get_self_fused_attn_bwd_workspace_sizes", &GetSelfFusedAttnBackwardWorkspaceSizes); - m.def("get_cross_fused_attn_fwd_workspace_sizes", &GetCrossFusedAttnForwardWorkspaceSizes); - m.def("get_cross_fused_attn_bwd_workspace_sizes", &GetCrossFusedAttnBackwardWorkspaceSizes); m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); m.def("get_fused_attn_bwd_workspace_sizes", &GetFusedAttnBackwardWorkspaceSizes); diff --git a/transformer_engine/jax/csrc/modules.cpp b/transformer_engine/jax/csrc/modules.cpp index ceacc7c90f..1c4c468d51 100644 --- a/transformer_engine/jax/csrc/modules.cpp +++ b/transformer_engine/jax/csrc/modules.cpp @@ -11,14 +11,12 @@ #include #include -#include -#include #include #include #include #include "common/common.h" -#include "common/util/cuda_runtime.h" +#include "common/util/logging.h" #include "transformer_engine/activation.h" #include "transformer_engine/cast.h" #include "transformer_engine/fused_attn.h" @@ -96,13 +94,13 @@ pybind11::bytes PackCustomCallSoftmaxDescriptor(size_t batch_size, size_t 
paddin pybind11::bytes PackCustomCallFusedAttnDescriptor( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, - size_t wkspace_size, float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - DType dtype, DType wkspace_dtype, bool is_training) { + size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, + bool is_training) { return PackOpaque(CustomCallFusedAttnDescriptor{ input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, - bias_heads, head_dim, wkspace_size, scaling_factor, dropout_probability, - bias_type, mask_type, dtype, wkspace_dtype, is_training}); + bias_heads, head_dim, wkspace_size, scaling_factor, dropout_probability, bias_type, + mask_type, qkv_layout, dtype, wkspace_dtype, is_training}); } void TransposeImpl(void *input, size_t rows, size_t cols, DType dtype, cudaStream_t stream, @@ -942,12 +940,12 @@ void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, float dropout_probability, - size_t q_num_heads, size_t kv_num_heads, + size_t q_attn_heads, size_t kv_attn_heads, size_t q_max_seqlen, size_t kv_max_seqlen, size_t head_dim) { auto backend = nvte_get_fused_attn_backend( static_cast(q_dtype), static_cast(kv_dtype), qkv_layout, bias_type, - mask_type, dropout_probability, q_num_heads, kv_num_heads, q_max_seqlen, kv_max_seqlen, + mask_type, dropout_probability, q_attn_heads, kv_attn_heads, q_max_seqlen, kv_max_seqlen, head_dim); return backend; } @@ -1029,244 +1027,31 @@ void PrepareFusedAttnBackwardAuxTensors(NVTETensorPack *tensor_pack, } } -pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes( - size_t input_batch, size_t bias_batch, size_t max_seqlen, - size_t attn_heads, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) { - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD; - - auto qkv_shape = std::vector{input_batch * max_seqlen, 3, attn_heads, head_dim}; - auto bias_shape = std::vector{bias_batch, bias_heads, max_seqlen, max_seqlen}; - - auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); - auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype); - auto cu_seqlens_tensor = - TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); - auto o_tensor = TensorWrapper( - nullptr, std::vector{input_batch, max_seqlen, attn_heads, head_dim}, dtype); - auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); - auto rng_state_tensor = TensorWrapper(nullptr, std::vector{2}, DType::kInt64); - - auto backend = nvte_get_fused_attn_backend( - static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, - mask_type, dropout_probability, attn_heads, attn_heads, max_seqlen, max_seqlen, head_dim); - - NVTETensorPack aux_output_tensors; - nvte_tensor_pack_create(&aux_output_tensors); - - TensorWrapper query_workspace_tensor; - nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), - o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(), - rng_state_tensor.data(), max_seqlen, is_training, scaling_factor, - 
dropout_probability, qkv_layout, bias_type, mask_type, - query_workspace_tensor.data(), nullptr); - - auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); - return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); -} - -void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len) { - const CustomCallFusedAttnDescriptor &descriptor = - *UnpackOpaque(opaque, opaque_len); - - // input buffers from XLA - void *qkv = buffers[0]; - void *bias = buffers[1]; - void *cu_seqlens = buffers[2]; - void *seed = buffers[3]; - - // output buffers from XLA - void *output = buffers[4]; - void *softmax_aux = buffers[5]; - void *rng_state = buffers[6]; - void *workspace = buffers[7]; - - // tensor sizes - auto input_batch = descriptor.input_batch; - auto bias_batch = descriptor.bias_batch; - auto max_seqlen = descriptor.q_max_seqlen; - auto attn_heads = descriptor.attn_heads; - auto bias_heads = descriptor.bias_heads; - auto head_dim = descriptor.head_dim; - auto dropout_probability = descriptor.dropout_probability; - auto bias_type = descriptor.bias_type; - auto mask_type = descriptor.mask_type; - - auto dtype = descriptor.dtype; - auto qkv_shape = std::vector{input_batch * max_seqlen, 3, attn_heads, head_dim}; - auto bias_shape = std::vector{bias_batch, bias_heads, max_seqlen, max_seqlen}; - - // input tensors - auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); - auto bias_tensor = TensorWrapper(bias, bias_shape, dtype); - auto cu_seqlens_tensor = - TensorWrapper(cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); - - // output tensors - auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in FP16/BF16 - auto o_tensor = TensorWrapper( - output, std::vector{input_batch * max_seqlen, attn_heads, head_dim}, dtype); - - // prep RNG state - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD; - auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); - auto backend = nvte_get_fused_attn_backend( - static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, - mask_type, dropout_probability, attn_heads, attn_heads, max_seqlen, max_seqlen, head_dim); - PopulateRngStateAsync(rng_state, seed, max_seqlen, max_seqlen, backend, stream); - - // auxiliary tensors (to be propagated to the backward pass later) - NVTETensorPack aux_output_tensors; - nvte_tensor_pack_create(&aux_output_tensors); - PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend, - softmax_aux); - - // cuDNN workspace - auto wkspace_size = std::vector{descriptor.wkspace_size}; - auto wkspace_dtype = descriptor.wkspace_dtype; - auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype); - - nvte_fused_attn_fwd_qkvpacked(qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), - o_tensor.data(), &aux_output_tensors, cu_seqlens_tensor.data(), - rng_state_tensor.data(), max_seqlen, descriptor.is_training, - descriptor.scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, workspace_tensor.data(), stream); - - nvte_tensor_pack_destroy(&aux_output_tensors); -} - -pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes( - size_t input_batch, size_t bias_batch, size_t max_seqlen, - size_t attn_heads, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) { - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD; - - auto qkv_shape = 
std::vector{input_batch * max_seqlen, 3, attn_heads, head_dim}; - auto output_shape = std::vector{input_batch * max_seqlen, attn_heads, head_dim}; - auto bias_shape = std::vector{bias_batch, bias_heads, max_seqlen, max_seqlen}; - - auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); - auto output_tensor = TensorWrapper(nullptr, output_shape, dtype); - auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype); - // F16 doesn't use this tensor - auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); - - auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); - auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype); - - auto cu_seqlens_tensor = - TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); - - NVTETensorPack aux_input_tensors; - nvte_tensor_pack_create(&aux_input_tensors); - - TensorWrapper query_workspace_tensor; - nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), - cu_seqlens_tensor.data(), max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, - query_workspace_tensor.data(), nullptr); - - auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); - return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); -} - -void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len) { - const CustomCallFusedAttnDescriptor &descriptor = - *UnpackOpaque(opaque, opaque_len); - - // input buffers from XLA - void *qkv = buffers[0]; - void *bias = buffers[1]; - void *softmax_aux = buffers[2]; - void *rng_state = buffers[3]; - void *output = buffers[4]; - void *doutput = buffers[5]; - void *cu_seqlens = buffers[6]; - - // output buffers from XLA - void *dqkv = buffers[7]; - void *dbias = buffers[8]; - void *workspace = buffers[9]; - - // tensor sizes - auto input_batch = descriptor.input_batch; - auto bias_batch = descriptor.bias_batch; - auto max_seqlen = descriptor.q_max_seqlen; - auto attn_heads = descriptor.attn_heads; - auto bias_heads = descriptor.bias_heads; - auto head_dim = descriptor.head_dim; - auto scaling_factor = descriptor.scaling_factor; - auto dropout_probability = descriptor.dropout_probability; - auto bias_type = descriptor.bias_type; - auto mask_type = descriptor.mask_type; - - auto dtype = descriptor.dtype; - auto qkv_shape = std::vector{input_batch * max_seqlen, 3, attn_heads, head_dim}; - auto output_shape = std::vector{input_batch * max_seqlen, attn_heads, head_dim}; - auto bias_shape = std::vector{bias_batch, bias_heads, max_seqlen, max_seqlen}; - - // input tensors - auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); - auto output_tensor = TensorWrapper(output, output_shape, dtype); - auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype); - - // output tensors - auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in FP16/BF16 - auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype); - auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype); - auto cu_seqlens_tensor = - TensorWrapper(cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); - - // auxiliary tensors (propagated from the forward pass) - NVTETensorPack aux_input_tensors; - nvte_tensor_pack_create(&aux_input_tensors); - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BS3HD; - auto backend = nvte_get_fused_attn_backend( - 
static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, - mask_type, dropout_probability, attn_heads, attn_heads, max_seqlen, max_seqlen, head_dim); - PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux, - rng_state, bias); - - // cuDNN workspace - auto wkspace_size = std::vector{descriptor.wkspace_size}; - auto wkspace_dtype = descriptor.wkspace_dtype; - auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype); - - nvte_fused_attn_bwd_qkvpacked(qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), - cu_seqlens_tensor.data(), max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, - workspace_tensor.data(), stream); - - nvte_tensor_pack_destroy(&aux_input_tensors); -} - -pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes( +pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) { - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD; + float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training) { + // For qkv_packed + auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; + auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); + // For kv_packed auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; - auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; - auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype); + // For separate q, k, v + auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; + auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); + auto v_shape = k_shape; + auto v_tensor = TensorWrapper(nullptr, v_shape, dtype); + + auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype); - // FP16/BF16 doesn't use this tensor + // F16 doesn't use this tensor auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); auto o_tensor = TensorWrapper(nullptr, q_shape, dtype); @@ -1281,292 +1066,133 @@ pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes( nvte_tensor_pack_create(&aux_output_tensors); TensorWrapper query_workspace_tensor; - nvte_fused_attn_fwd_kvpacked(q_tensor.data(), kv_tensor.data(), bias_tensor.data(), - s_tensor.data(), o_tensor.data(), &aux_output_tensors, - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, - is_training, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, query_workspace_tensor.data(), nullptr); - - auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); - return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); -} - -void CrossFusedAttnForward(cudaStream_t 
stream, void **buffers, const char *opaque, - size_t opaque_len) { - const CustomCallFusedAttnDescriptor &descriptor = - *UnpackOpaque(opaque, opaque_len); - - // input buffers from XLA - void *q = buffers[0]; - void *kv = buffers[1]; - void *bias = buffers[2]; - void *q_cu_seqlens = buffers[3]; - void *kv_cu_seqlens = buffers[4]; - void *seed = buffers[5]; - - // output buffers from XLA - void *output = buffers[6]; - void *softmax_aux = buffers[7]; - void *rng_state = buffers[8]; - void *workspace = buffers[9]; - - // tensor sizes - auto input_batch = descriptor.input_batch; - auto bias_batch = descriptor.bias_batch; - auto q_max_seqlen = descriptor.q_max_seqlen; - auto kv_max_seqlen = descriptor.kv_max_seqlen; - auto attn_heads = descriptor.attn_heads; - auto num_gqa_groups = descriptor.num_gqa_groups; - auto bias_heads = descriptor.bias_heads; - auto head_dim = descriptor.head_dim; - auto scaling_factor = descriptor.scaling_factor; - auto dropout_probability = descriptor.dropout_probability; - auto bias_type = descriptor.bias_type; - auto mask_type = descriptor.mask_type; - - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; - auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; - - // input tensors - auto dtype = descriptor.dtype; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto kv_tensor = TensorWrapper(kv, kv_shape, dtype); - auto bias_tensor = TensorWrapper(bias, bias_shape, dtype); - - // output tensors - auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in FP16/BF16 - auto o_tensor = TensorWrapper(output, q_shape, dtype); - auto q_cu_seqlens_tensor = - TensorWrapper(q_cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); - auto kv_cu_seqlens_tensor = - TensorWrapper(kv_cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); - - // prep RNG state - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD; - auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); - auto backend = nvte_get_fused_attn_backend( - static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, - mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, - head_dim); - PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); - - // auxiliary tensors (to be propagated to the backward pass later) - NVTETensorPack aux_output_tensors; - nvte_tensor_pack_create(&aux_output_tensors); - PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend, - softmax_aux); - - // cuDNN workspace - auto workspace_tensor = TensorWrapper(workspace, std::vector{descriptor.wkspace_size}, - descriptor.wkspace_dtype); - - nvte_fused_attn_fwd_kvpacked(q_tensor.data(), kv_tensor.data(), bias_tensor.data(), - s_tensor.data(), o_tensor.data(), &aux_output_tensors, - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, - descriptor.is_training, scaling_factor, dropout_probability, - qkv_layout, bias_type, mask_type, workspace_tensor.data(), stream); + if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) { + assert(q_max_seqlen == kv_max_seqlen); + nvte_fused_attn_fwd_qkvpacked( + qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), + &aux_output_tensors, q_cu_seqlens_tensor.data(), dummy_rng_state_tensor.data(), + q_max_seqlen, is_training, scaling_factor, 
dropout_probability, qkv_layout, bias_type, + mask_type, query_workspace_tensor.data(), nullptr); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { + nvte_fused_attn_fwd_kvpacked(q_tensor.data(), kv_tensor.data(), bias_tensor.data(), + s_tensor.data(), o_tensor.data(), &aux_output_tensors, + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, + is_training, scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, query_workspace_tensor.data(), nullptr); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { + nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), + s_tensor.data(), o_tensor.data(), &aux_output_tensors, + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + query_workspace_tensor.data(), nullptr); + } else { + NVTE_ERROR("Unsupported QKVLayout."); + } - nvte_tensor_pack_destroy(&aux_output_tensors); + auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape()); + return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype()); } -pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes( - size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) { - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD; - - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; - auto output_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; - - auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); - auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype); - auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype); +pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( + size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, + size_t num_gqa_groups, size_t head_dim, float scaling_factor, float dropout_probability, + NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, + bool is_training) { + auto output_shape = std::vector{batch_size * q_max_seqlen, attn_heads, head_dim}; auto output_tensor = TensorWrapper(nullptr, output_shape, dtype); - // FP16/BF16 doesn't use this tensor - auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); + auto doutput_tensor = TensorWrapper(nullptr, output_shape, dtype); - auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); - auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype); + auto bias_shape = std::vector{1, attn_heads, q_max_seqlen, kv_max_seqlen}; auto dbias_tensor = TensorWrapper(nullptr, bias_shape, dtype); - auto q_cu_seqlens_tensor = - TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); - auto kv_cu_seqlens_tensor = - TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); - - NVTETensorPack aux_input_tensors; - nvte_tensor_pack_create(&aux_input_tensors); - - TensorWrapper query_workspace_tensor; - nvte_fused_attn_bwd_kvpacked( - q_tensor.data(), 
kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for FP16/BF16 - s_tensor.data(), // not used for FP16/BF16 - &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - query_workspace_tensor.data(), nullptr); - - auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); - return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); -} - -void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len) { - const CustomCallFusedAttnDescriptor &descriptor = - *UnpackOpaque(opaque, opaque_len); - - // input buffers from XLA - void *q = buffers[0]; - void *kv = buffers[1]; - void *bias = buffers[2]; - void *softmax_aux = buffers[3]; - void *rng_state = buffers[4]; - void *output = buffers[5]; - void *doutput = buffers[6]; - void *q_cu_seqlens = buffers[7]; - void *kv_cu_seqlens = buffers[8]; - - // output buffers from XLA - void *dq = buffers[9]; - void *dkv = buffers[10]; - void *dbias = buffers[11]; - void *workspace = buffers[12]; - - // tensor sizes - auto input_batch = descriptor.input_batch; - auto bias_batch = descriptor.bias_batch; - auto q_max_seqlen = descriptor.q_max_seqlen; - auto kv_max_seqlen = descriptor.kv_max_seqlen; - auto attn_heads = descriptor.attn_heads; - auto num_gqa_groups = descriptor.num_gqa_groups; - auto bias_heads = descriptor.bias_heads; - auto head_dim = descriptor.head_dim; - auto scaling_factor = descriptor.scaling_factor; - auto dropout_probability = descriptor.dropout_probability; - auto bias_type = descriptor.bias_type; - auto mask_type = descriptor.mask_type; - - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto kv_shape = std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; - auto output_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; - - // input tensors - auto dtype = descriptor.dtype; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto kv_tensor = TensorWrapper(kv, kv_shape, dtype); - auto output_tensor = TensorWrapper(output, output_shape, dtype); - auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype); + // F16 doesn't use s_tensor + auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); - // output tensors - auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in FP16/BF16 - auto dq_tensor = TensorWrapper(dq, q_shape, dtype); - auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype); - auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype); auto q_cu_seqlens_tensor = - TensorWrapper(q_cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); + TensorWrapper(nullptr, std::vector{batch_size + 1}, DType::kInt32); auto kv_cu_seqlens_tensor = - TensorWrapper(kv_cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); + TensorWrapper(nullptr, std::vector{batch_size + 1}, DType::kInt32); - // auxiliary tensors (propagated from the forward pass) NVTETensorPack aux_input_tensors; nvte_tensor_pack_create(&aux_input_tensors); - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BS2HD; - auto backend = nvte_get_fused_attn_backend( - static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, - mask_type, dropout_probability, attn_heads, num_gqa_groups, 
q_max_seqlen, kv_max_seqlen, - head_dim); - PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux, - rng_state, bias); - - // cuDNN workspace - auto wkspace_size = std::vector{descriptor.wkspace_size}; - auto wkspace_dtype = descriptor.wkspace_dtype; - auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype); - - nvte_fused_attn_bwd_kvpacked( - q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), - s_tensor.data(), // not used for FP16/BF16 - s_tensor.data(), // not used for FP16/BF16 - &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - workspace_tensor.data(), stream); - - nvte_tensor_pack_destroy(&aux_input_tensors); -} - -pybind11::tuple GetFusedAttnForwardWorkspaceSizes( - size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) { - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD; - - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; - auto v_shape = k_shape; - auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; - - auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); - auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); - auto v_tensor = TensorWrapper(nullptr, v_shape, dtype); - - auto bias_tensor = TensorWrapper(nullptr, bias_shape, dtype); - - // F16 doesn't use this tensor - auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); - auto o_tensor = TensorWrapper(nullptr, q_shape, dtype); - - auto q_cu_seqlens_tensor = - TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); - auto kv_cu_seqlens_tensor = - TensorWrapper(nullptr, std::vector{input_batch + 1}, DType::kInt32); - - auto dummy_rng_state_tensor = TensorWrapper(nullptr, std::vector{2}, DType::kInt64); - - NVTETensorPack aux_output_tensors; - nvte_tensor_pack_create(&aux_output_tensors); TensorWrapper query_workspace_tensor; - nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), - s_tensor.data(), o_tensor.data(), &aux_output_tensors, - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - dummy_rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, is_training, - scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, - query_workspace_tensor.data(), nullptr); - auto work_shape = MakeShapeVector(query_workspace_tensor.shape()); - return pybind11::make_tuple(work_shape, query_workspace_tensor.dtype()); + if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) { + assert(q_max_seqlen == kv_max_seqlen); + auto qkv_shape = std::vector{batch_size * q_max_seqlen, 3, attn_heads, head_dim}; + auto qkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); + auto dqkv_tensor = TensorWrapper(nullptr, qkv_shape, dtype); + nvte_fused_attn_bwd_qkvpacked( + qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), + q_max_seqlen, 
scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + query_workspace_tensor.data(), nullptr); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { + auto q_shape = std::vector{batch_size * q_max_seqlen, attn_heads, head_dim}; + auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto kv_shape = + std::vector{batch_size * kv_max_seqlen, 2, num_gqa_groups, head_dim}; + auto kv_tensor = TensorWrapper(nullptr, kv_shape, dtype); + auto dkv_tensor = TensorWrapper(nullptr, kv_shape, dtype); + nvte_fused_attn_bwd_kvpacked( + q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + query_workspace_tensor.data(), nullptr); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { + auto q_shape = std::vector{batch_size * q_max_seqlen, attn_heads, head_dim}; + auto q_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto dq_tensor = TensorWrapper(nullptr, q_shape, dtype); + auto k_shape = std::vector{batch_size * kv_max_seqlen, num_gqa_groups, head_dim}; + auto k_tensor = TensorWrapper(nullptr, k_shape, dtype); + auto dk_tensor = TensorWrapper(nullptr, k_shape, dtype); + auto v_shape = k_shape; + auto v_tensor = TensorWrapper(nullptr, v_shape, dtype); + auto dv_tensor = TensorWrapper(nullptr, v_shape, dtype); + nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), + dv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + query_workspace_tensor.data(), nullptr); + } else { + NVTE_ERROR("Unsupported QKVLayout."); + } + + auto workspace_shape = MakeShapeVector(query_workspace_tensor.shape()); + return pybind11::make_tuple(workspace_shape, query_workspace_tensor.dtype()); } void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len) { const CustomCallFusedAttnDescriptor &descriptor = *UnpackOpaque(opaque, opaque_len); - // input buffers from XLA - void *q = buffers[0]; - void *k = buffers[1]; - void *v = buffers[2]; + /* Input buffers from XLA */ + /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */ void *bias = buffers[3]; void *q_cu_seqlens = buffers[4]; void *kv_cu_seqlens = buffers[5]; void *seed = buffers[6]; - // output buffers from XLA + /* Output buffer from XLA */ void *output = buffers[7]; void *softmax_aux = buffers[8]; void *rng_state = buffers[9]; void *workspace = buffers[10]; - // tensor sizes + /* Descriptor */ auto input_batch = descriptor.input_batch; auto bias_batch = descriptor.bias_batch; auto q_max_seqlen = descriptor.q_max_seqlen; @@ -1579,29 +1205,26 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s auto dropout_probability = descriptor.dropout_probability; auto bias_type = descriptor.bias_type; auto mask_type = descriptor.mask_type; + auto qkv_layout = descriptor.qkv_layout; + auto dtype = 
descriptor.dtype; + /* Input tensors */ auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; auto v_shape = k_shape; auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; - - // input tensors - auto dtype = descriptor.dtype; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto k_tensor = TensorWrapper(k, k_shape, dtype); - auto v_tensor = TensorWrapper(v, v_shape, dtype); auto bias_tensor = TensorWrapper(bias, bias_shape, dtype); - // output tensors + /* Output tensors */ auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in F16 - auto o_tensor = TensorWrapper(output, q_shape, dtype); + auto o_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto o_tensor = TensorWrapper(output, o_shape, dtype); auto q_cu_seqlens_tensor = TensorWrapper(q_cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); auto kv_cu_seqlens_tensor = TensorWrapper(kv_cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); - // prep RNG state - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD; + /* Prepare RNG state */ auto rng_state_tensor = TensorWrapper(rng_state, std::vector{2}, DType::kInt64); auto backend = nvte_get_fused_attn_backend( static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, @@ -1609,22 +1232,59 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s head_dim); PopulateRngStateAsync(rng_state, seed, q_max_seqlen, kv_max_seqlen, backend, stream); - // auxiliary tensors (to be propagated to the backward pass later) + /* Auxiliary tensors (to be propagated to the backward pass later) */ NVTETensorPack aux_output_tensors; nvte_tensor_pack_create(&aux_output_tensors); PrepareFusedAttnForwardAuxTensors(&aux_output_tensors, &descriptor, bias_type, backend, softmax_aux); - // cuDNN workspace + /* cuDNN workspace */ auto workspace_tensor = TensorWrapper(workspace, std::vector{descriptor.wkspace_size}, descriptor.wkspace_dtype); - nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), - s_tensor.data(), o_tensor.data(), &aux_output_tensors, - q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), - rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, - descriptor.is_training, scaling_factor, dropout_probability, qkv_layout, - bias_type, mask_type, workspace_tensor.data(), stream); + /* Call the underlying NVTE API */ + if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) { + auto qkv = buffers[0]; + auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; + auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); + nvte_fused_attn_fwd_qkvpacked( + qkv_tensor.data(), bias_tensor.data(), s_tensor.data(), o_tensor.data(), + &aux_output_tensors, q_cu_seqlens_tensor.data(), rng_state_tensor.data(), q_max_seqlen, + descriptor.is_training, descriptor.scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, workspace_tensor.data(), stream); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { + auto q = buffers[0]; + auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto kv = buffers[1]; + auto kv_shape = + std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; + auto kv_tensor = TensorWrapper(kv, kv_shape, dtype); + nvte_fused_attn_fwd_kvpacked( + q_tensor.data(), kv_tensor.data(),
bias_tensor.data(), s_tensor.data(), o_tensor.data(), + &aux_output_tensors, q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, descriptor.is_training, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + workspace_tensor.data(), stream); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { + auto q = buffers[0]; + auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto k = buffers[1]; + auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; + auto k_tensor = TensorWrapper(k, k_shape, dtype); + auto v = buffers[2]; + auto v_shape = k_shape; + auto v_tensor = TensorWrapper(v, v_shape, dtype); + nvte_fused_attn_fwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), bias_tensor.data(), + s_tensor.data(), o_tensor.data(), &aux_output_tensors, + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), + rng_state_tensor.data(), q_max_seqlen, kv_max_seqlen, + descriptor.is_training, scaling_factor, dropout_probability, qkv_layout, + bias_type, mask_type, workspace_tensor.data(), stream); + } else { + NVTE_ERROR("Unsupported qkv_layout."); + } nvte_tensor_pack_destroy(&aux_output_tensors); } @@ -1632,10 +1292,8 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training) { - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD; - + float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training) { auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; auto v_shape = k_shape; @@ -1682,10 +1340,8 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, const CustomCallFusedAttnDescriptor &descriptor = *UnpackOpaque(opaque, opaque_len); - // input buffers from XLA - void *q = buffers[0]; - void *k = buffers[1]; - void *v = buffers[2]; + /* Input buffers from XLA */ + /* Buffers[0-2] are q, k, v, which are parsed later for different qkv_layout */ void *bias = buffers[3]; void *softmax_aux = buffers[4]; void *rng_state = buffers[5]; @@ -1694,14 +1350,12 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, void *q_cu_seqlens = buffers[8]; void *kv_cu_seqlens = buffers[9]; - // output buffers from XLA - void *dq = buffers[10]; - void *dk = buffers[11]; - void *dv = buffers[12]; + /* Output buffer from XLA */ + /* Buffers[10-12] are dq, dk, dv, which are parsed later for different qkv_layout */ void *dbias = buffers[13]; void *workspace = buffers[14]; - // tensor sizes + /* Descriptor */ auto input_batch = descriptor.input_batch; auto bias_batch = descriptor.bias_batch; auto q_max_seqlen = descriptor.q_max_seqlen; @@ -1714,36 +1368,26 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, auto dropout_probability = descriptor.dropout_probability; auto bias_type = descriptor.bias_type; auto mask_type = descriptor.mask_type; + auto qkv_layout = 
descriptor.qkv_layout; + auto dtype = descriptor.dtype; - auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; - auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; - auto v_shape = k_shape; + /* Input tensors */ auto output_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; auto bias_shape = std::vector{bias_batch, bias_heads, q_max_seqlen, kv_max_seqlen}; - - // input tensors - auto dtype = descriptor.dtype; - auto q_tensor = TensorWrapper(q, q_shape, dtype); - auto k_tensor = TensorWrapper(k, k_shape, dtype); - auto v_tensor = TensorWrapper(v, v_shape, dtype); auto output_tensor = TensorWrapper(output, output_shape, dtype); auto doutput_tensor = TensorWrapper(doutput, output_shape, dtype); - // output tensors + /* Output tensors */ auto s_tensor = TensorWrapper(nullptr, std::vector{1}, dtype); // not used in F16 - auto dq_tensor = TensorWrapper(dq, q_shape, dtype); - auto dk_tensor = TensorWrapper(dk, k_shape, dtype); - auto dv_tensor = TensorWrapper(dv, v_shape, dtype); auto dbias_tensor = TensorWrapper(dbias, bias_shape, dtype); auto q_cu_seqlens_tensor = TensorWrapper(q_cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); auto kv_cu_seqlens_tensor = TensorWrapper(kv_cu_seqlens, std::vector{input_batch + 1}, DType::kInt32); - // auxiliary tensors (propagated from the forward pass) + /* Auxiliary tensors (propagated from the forward pass) */ NVTETensorPack aux_input_tensors; nvte_tensor_pack_create(&aux_input_tensors); - constexpr auto qkv_layout = NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD; auto backend = nvte_get_fused_attn_backend( static_cast(dtype), static_cast(dtype), qkv_layout, bias_type, mask_type, dropout_probability, attn_heads, num_gqa_groups, q_max_seqlen, kv_max_seqlen, @@ -1751,20 +1395,73 @@ void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, PrepareFusedAttnBackwardAuxTensors(&aux_input_tensors, &descriptor, backend, softmax_aux, rng_state, bias); - // cuDNN workspace + /* cuDNN workspace */ auto wkspace_size = std::vector{descriptor.wkspace_size}; auto wkspace_dtype = descriptor.wkspace_dtype; auto workspace_tensor = TensorWrapper(workspace, wkspace_size, wkspace_dtype); - nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), - doutput_tensor.data(), - s_tensor.data(), // not used for F16 - s_tensor.data(), // not used for F16 - &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), dv_tensor.data(), - dbias_tensor.data(), q_cu_seqlens_tensor.data(), - kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, scaling_factor, - dropout_probability, qkv_layout, bias_type, mask_type, - workspace_tensor.data(), stream); + /* Call the underlying NVTE API */ + if (qkv_layout == NVTE_QKV_Layout::NVTE_BS3HD) { + auto qkv = buffers[0]; + auto qkv_shape = std::vector{input_batch * q_max_seqlen, 3, attn_heads, head_dim}; + auto qkv_tensor = TensorWrapper(qkv, qkv_shape, dtype); + auto dqkv = buffers[10]; + auto dqkv_tensor = TensorWrapper(dqkv, qkv_shape, dtype); + nvte_fused_attn_bwd_qkvpacked( + qkv_tensor.data(), output_tensor.data(), doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dqkv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), + q_max_seqlen, scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + workspace_tensor.data(), stream); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BS2HD) { + auto q = buffers[0]; + auto
q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto kv = buffers[1]; + auto kv_shape = + std::vector{input_batch * kv_max_seqlen, 2, num_gqa_groups, head_dim}; + auto kv_tensor = TensorWrapper(kv, kv_shape, dtype); + auto dq = buffers[10]; + auto dq_tensor = TensorWrapper(dq, q_shape, dtype); + auto dkv = buffers[11]; + auto dkv_tensor = TensorWrapper(dkv, kv_shape, dtype); + nvte_fused_attn_bwd_kvpacked( + q_tensor.data(), kv_tensor.data(), output_tensor.data(), doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dkv_tensor.data(), dbias_tensor.data(), + q_cu_seqlens_tensor.data(), kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + workspace_tensor.data(), stream); + } else if (qkv_layout == NVTE_QKV_Layout::NVTE_BSHD_BSHD_BSHD) { + auto q = buffers[0]; + auto q_shape = std::vector{input_batch * q_max_seqlen, attn_heads, head_dim}; + auto q_tensor = TensorWrapper(q, q_shape, dtype); + auto k = buffers[1]; + auto k_shape = std::vector{input_batch * kv_max_seqlen, num_gqa_groups, head_dim}; + auto k_tensor = TensorWrapper(k, k_shape, dtype); + auto v = buffers[2]; + auto v_shape = k_shape; + auto v_tensor = TensorWrapper(v, v_shape, dtype); + auto dq = buffers[10]; + auto dq_tensor = TensorWrapper(dq, q_shape, dtype); + auto dk = buffers[11]; + auto dk_tensor = TensorWrapper(dk, k_shape, dtype); + auto dv = buffers[12]; + auto dv_tensor = TensorWrapper(dv, v_shape, dtype); + nvte_fused_attn_bwd(q_tensor.data(), k_tensor.data(), v_tensor.data(), output_tensor.data(), + doutput_tensor.data(), + s_tensor.data(), // not used for F16 + s_tensor.data(), // not used for F16 + &aux_input_tensors, dq_tensor.data(), dk_tensor.data(), + dv_tensor.data(), dbias_tensor.data(), q_cu_seqlens_tensor.data(), + kv_cu_seqlens_tensor.data(), q_max_seqlen, kv_max_seqlen, + scaling_factor, dropout_probability, qkv_layout, bias_type, mask_type, + workspace_tensor.data(), stream); + } else { + NVTE_ERROR("Unsupported qkv_layout."); + } nvte_tensor_pack_destroy(&aux_input_tensors); } diff --git a/transformer_engine/jax/csrc/modules.h b/transformer_engine/jax/csrc/modules.h index a2b873235e..e392931d04 100644 --- a/transformer_engine/jax/csrc/modules.h +++ b/transformer_engine/jax/csrc/modules.h @@ -118,17 +118,18 @@ struct CustomCallFusedAttnDescriptor { float dropout_probability; NVTE_Bias_Type bias_type; NVTE_Mask_Type mask_type; + NVTE_QKV_Layout qkv_layout; DType dtype; DType wkspace_dtype; bool is_training; }; pybind11::bytes PackCustomCallFusedAttnDescriptor( - size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, + size_t input_batch, size_t batch_size, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, - size_t wkspace_size, float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, - DType dtype, DType wkspace_dtype, bool is_training); + size_t wkspace_size, float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, + bool is_training); NVTE_Fused_Attn_Backend GetFusedAttnBackend(DType q_dtype, DType kv_dtype, NVTE_QKV_Layout qkv_layout, NVTE_Bias_Type bias_type, @@ -207,55 +208,19 @@ void 
ScaledUpperTriangMaskedSoftmaxForward(cudaStream_t stream, void **buffers, void ScaledUpperTriangMaskedSoftmaxBackward(cudaStream_t stream, void **buffers, const char *opaque, std::size_t opaque_len); -pybind11::tuple GetSelfFusedAttnForwardWorkspaceSizes( - size_t input_batch, size_t bias_batch, size_t max_seqlen, - size_t attn_heads, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training); - -void SelfFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len); - -pybind11::tuple GetSelfFusedAttnBackwardWorkspaceSizes( - size_t input_batch, size_t bias_batch, size_t max_seqlen, - size_t attn_heads, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training); - -void SelfFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len); - -pybind11::tuple GetCrossFusedAttnForwardWorkspaceSizes( - size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training); - -void CrossFusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len); - -pybind11::tuple GetCrossFusedAttnBackwardWorkspaceSizes( - size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, - size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training); - -void CrossFusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, - size_t opaque_len); - pybind11::tuple GetFusedAttnForwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training); + float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training); void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); pybind11::tuple GetFusedAttnBackwardWorkspaceSizes( size_t input_batch, size_t bias_batch, size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, size_t head_dim, - float scaling_factor, float dropout_probability, - NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, DType dtype, bool is_training); + float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, + NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, bool is_training); void FusedAttnBackward(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); diff --git a/transformer_engine/jax/flax/transformer.py b/transformer_engine/jax/flax/transformer.py index 6979cffd90..fcf06aa128 100644 --- a/transformer_engine/jax/flax/transformer.py +++ b/transformer_engine/jax/flax/transformer.py @@ -26,7 +26,7 @@ from .module import LayerNorm, Softmax from ..fused_attn import AttnBiasType, AttnMaskType, 
QKVLayout from ..fused_attn import is_fused_attn_kernel_available, canonicalize_attn_mask_type -from ..fused_attn import self_fused_attn, cross_fused_attn, fused_attn +from ..fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked, fused_attn from ..softmax import SoftmaxType from ..sharding import num_of_devices from ..sharding import get_sharding_map_logic_axis_to_mesh_axis @@ -190,16 +190,19 @@ def __call__(self, def convert_to_softmax_type(attn_mask_type, mask): """Convert the attn_mask_type to SoftmaxType""" + # mask is ignored for no_mask and causal_mask + if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: + mask = None if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]: - return SoftmaxType.SCALED_UPPER_TRIANG_MASKED + return SoftmaxType.SCALED_UPPER_TRIANG_MASKED, mask if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.PADDING_MASK]: if mask is not None: - return SoftmaxType.SCALED_MASKED - return SoftmaxType.SCALED - raise ValueError(f"Unsupported {attn_mask_type=}, " - "supported attn_mask_type = {'causal', 'padding'}") + return SoftmaxType.SCALED_MASKED, mask + return SoftmaxType.SCALED, mask + raise ValueError(f"Unsupported {attn_mask_type=}, supported attn_mask_type=" + "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}") - softmax_type = convert_to_softmax_type(self.attn_mask_type, mask) + softmax_type, mask = convert_to_softmax_type(self.attn_mask_type, mask) attn_weights = Softmax(softmax_type=softmax_type, scale_factor=fused_scale_factor)(attn_weights, mask, @@ -266,15 +269,15 @@ def __call__(self, qkv_packed = query if self.transpose_batch_sequence: qkv_packed = qkv_packed.transpose([1, 0, 2, 3, 4]) - x = self_fused_attn(qkv_packed, - bias, - mask, - seed, - attn_mask_type=self.attn_mask_type, - attn_bias_type=self.attn_bias_type, - scaling_factor=scale_factor, - dropout_probability=self.attention_dropout, - is_training=not deterministic) + x = fused_attn_qkvpacked(qkv_packed, + bias, + mask, + seed, + attn_mask_type=self.attn_mask_type, + attn_bias_type=self.attn_bias_type, + scaling_factor=scale_factor, + dropout_probability=self.attention_dropout, + is_training=not deterministic) elif self.qkv_layout == QKVLayout.BSHD_BS2HD: """kvpacked format, treat query: query tensor, shape = [..., h, d] @@ -285,16 +288,16 @@ def __call__(self, if self.transpose_batch_sequence: query = query.transpose([1, 0, 2, 3]) kv_packed = kv_packed.transpose([1, 0, 2, 3, 4]) - x = cross_fused_attn(query, - kv_packed, - bias, - mask, - seed, - attn_mask_type=self.attn_mask_type, - attn_bias_type=self.attn_bias_type, - scaling_factor=scale_factor, - dropout_probability=self.attention_dropout, - is_training=not deterministic) + x = fused_attn_kvpacked(query, + kv_packed, + bias, + mask, + seed, + attn_mask_type=self.attn_mask_type, + attn_bias_type=self.attn_bias_type, + scaling_factor=scale_factor, + dropout_probability=self.attention_dropout, + is_training=not deterministic) elif self.qkv_layout == QKVLayout.BSHD_BSHD_BSHD: if self.transpose_batch_sequence: query = query.transpose([1, 0, 2, 3]) @@ -358,11 +361,27 @@ class DotProductAttention(nn.Module): # pylint: disable=too-few-public-method attention_dropout: float, default = 0.0 Dropout probability for the dropout op after the softmax. attn_mask_type: str, default = 'causal' - Type of the attention mask passed into softmax operation in the self attention. - Available options: {'no_mask', 'padding', 'causal', 'causal_padding'} - Introduced in v0.10.0. 
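Note on the renamed JAX entry points exercised in the hunks above: the BS3HD layout packs query/key/value along a new third axis into [batch, seqlen, 3, heads, dim] and is served by fused_attn_qkvpacked, while BSHD_BS2HD packs only key/value into [batch, seqlen, 2, heads, dim] and is served by fused_attn_kvpacked. A minimal sketch, not part of the patch, with hypothetical sizes and assuming a GPU whose fused-attention backend supports these settings:

import jax
import jax.numpy as jnp
from transformer_engine.jax.fused_attn import (AttnBiasType, AttnMaskType,
                                               fused_attn_qkvpacked, fused_attn_kvpacked)

b, s, h, d = 2, 128, 16, 64                       # hypothetical sizes
k1, k2, k3 = jax.random.split(jax.random.PRNGKey(0), 3)
q = jax.random.normal(k1, (b, s, h, d), jnp.bfloat16)
k = jax.random.normal(k2, (b, s, h, d), jnp.bfloat16)
v = jax.random.normal(k3, (b, s, h, d), jnp.bfloat16)

# BS3HD: stack q/k/v along a new axis 2 -> [b, s, 3, h, d]
qkv = jnp.stack([q, k, v], axis=2)
out_qkvpacked = fused_attn_qkvpacked(qkv, None, None, None,
                                     attn_bias_type=AttnBiasType.NO_BIAS,
                                     attn_mask_type=AttnMaskType.CAUSAL_MASK,
                                     scaling_factor=d**-0.5,
                                     dropout_probability=0.0,
                                     is_training=True)

# BSHD_BS2HD: stack k/v along a new axis 2 -> [b, s, 2, h, d]
kv = jnp.stack([k, v], axis=2)
out_kvpacked = fused_attn_kvpacked(q, kv, None, None, None,
                                   attn_bias_type=AttnBiasType.NO_BIAS,
                                   attn_mask_type=AttnMaskType.CAUSAL_MASK,
                                   scaling_factor=d**-0.5,
                                   dropout_probability=0.0,
                                   is_training=True)

The bias, mask, and seed arguments are left as None here because the causal path derives full sequence lengths on its own and no dropout is applied.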
+ This parameter specifies the type of attention mask to be applied during the softmax + operation. + Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'} + + Each described below: + + * no_mask: No attention mask is applied. This means the attention will consider the + full sequence without any restrictions. + * padding: Indicates the presence of padding at the end of each sequence. + Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the + :attr:`__call__` method to specify the padding positions. + * causal: An upper triangular mask is applied to the softmax inputs, + ensuring that the prediction for a certain position is only dependent on known outputs + from positions before it. + * causal_padding / padding_causal: A combination of both causal and padding masks. + Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect. + + .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'. + attn_bias_type: Optional[str], default = None - Type of the attention bias passed in the self attention. + Type of the attention bias passed in the attention. Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}. When default is present, the type is automatically decided by the MHA's bias parameter. Where it is :attr:`post_scale_bias` if there is bias. Otherwise :attr:`no_bias` is used. @@ -438,6 +457,7 @@ def __call__(self, mask: jax.numpy.ndarray, default = None Boolean tensor used to mask out the attention softmax input. :attr:`True` means to mask out the corresponding values. + Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'. bias: jax.numpy.ndarray, default = None A tensor used to shift attention softmax input. *: @@ -639,9 +659,25 @@ class MultiHeadAttention(nn.Module): # pylint: disable=too-few-public-methods attention_dropout: float, default = 0.0 Dropout probability for the dropout op after the softmax. attn_mask_type: str, default = 'causal' - Type of the attention mask passed into softmax operation in the attention. - Available options: {'no_mask', 'padding', 'causal', 'causal_padding'} - Introduced in v0.10.0. + This parameter specifies the type of attention mask to be applied during the softmax + operation. + Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'} + + Each described below: + + * no_mask: No attention mask is applied. This means the attention will consider the + full sequence without any restrictions. + * padding: Indicates the presence of padding at the end of each sequence. + Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the + :attr:`__call__` method to specify the padding positions. + * causal: An upper triangular mask is applied to the softmax inputs, + ensuring that the prediction for a certain position is only dependent on known outputs + from positions before it. + * causal_padding / padding_causal: A combination of both causal and padding masks. + Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect. + + .. note:: :attr:`mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'. + attn_bias_type: Optional[str], default = None Type of the attention bias passed in the attention. Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}. @@ -809,6 +845,7 @@ def __call__(self, mask: jax.numpy.ndarray, default = None Boolean tensor used to mask out the attention softmax input. 
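A minimal sketch, not part of the patch, of building the [batch, 1, max_seqlen_q, max_seqlen_kv] boolean mask documented above from per-sequence lengths, with True marking the trailing padding to be masked out; the helper name and the example lengths are hypothetical:

import jax.numpy as jnp

def make_padding_mask(q_seqlens, kv_seqlens, max_seqlen_q, max_seqlen_kv):
    # True where a position must be masked out, i.e. the trailing padding.
    q_valid = jnp.arange(max_seqlen_q)[None, :] < q_seqlens[:, None]      # [b, s_q]
    kv_valid = jnp.arange(max_seqlen_kv)[None, :] < kv_seqlens[:, None]   # [b, s_kv]
    valid = q_valid[:, :, None] & kv_valid[:, None, :]                    # [b, s_q, s_kv]
    return jnp.logical_not(valid)[:, None, :, :]                          # [b, 1, s_q, s_kv]

mask = make_padding_mask(jnp.array([120, 64]), jnp.array([120, 64]), 128, 128)

This matches the convention used by the forward rules in fused_attn.py, which recover the per-sequence lengths by summing the negated mask.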
:attr:`True` means mask out the corresponding values. + Ignored when :attr:`self.attn_mask_type` is either 'no_mask' or 'causal'. bias: jax.numpy.ndarray, default = None A tensor used to shift the attention softmax input. * @@ -1299,9 +1336,25 @@ class TransformerLayer(nn.Module): # pylint: disable=too-few-public-methods is added after self-attention.this can be used for structures like `T5` Transformer in conjunction with the TransformerLayerType.ENCODER option. self_attn_mask_type: str, default = 'causal' - Type of the attention mask passed into softmax operation in the self attention. - Available options: {'no_mask', 'padding', 'causal', 'causal_padding'} - Introduced in v0.10.0. + This parameter specifies the type of attention mask to be applied during the softmax + operation in the self attention. + Available options are {'no_mask', 'padding', 'causal', 'causal_padding', 'padding_causal'} + + Each described below: + + * no_mask: No attention mask is applied. This means the self attention will consider the + full sequence without any restrictions. + * padding: Indicates the presence of padding at the end of each sequence. + Users must provide a mask with the shape [batch, 1, max_seqlen_q, max_seqlen_kv] in the + :attr:`__call__` method to specify the padding positions. + * causal: An upper triangular mask is applied to the softmax inputs, + ensuring that the prediction for a certain position is only dependent on known outputs + from positions before it. + * causal_padding / padding_causal: A combination of both causal and padding masks. + Both 'causal_padding' and 'padding_causal' are acceptable and have the same effect. + + .. note:: :attr:`attention_mask` in :attr:`__call__` is ignored for 'no_mask' and 'causal'. + self_attn_bias_type: Optional[str], default = None Type of the attention bias passed into the self attention. Available options: {'no_bias', 'pre_scale_bias', 'post_scale_bias'}. @@ -1420,9 +1473,12 @@ def __call__(self, :attr:`layer_type=TransformerLayerType.DECODER`. attention_mask : jax.numpy.ndarray, default = None Boolean tensor used to mask out self-attention softmax input. + :attr:`True` means mask out the corresponding values. + Ignored when :attr:`self.self_attn_mask_type` is either 'no_mask' or 'causal'. encoder_decoder_mask: jax.numpy.ndarray, default = None Boolean tensor used to mask out cross-attention softmax input when :attr:`layer_type=TransformerLayerType.DECODER`. + :attr:`True` means mask out the corresponding values. deterministic: bool, default = False Disable dropout layers if set to True. 
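Since the self-attention mask is ignored for 'no_mask' and 'causal', a decoder-style layer usually only needs an encoder_decoder_mask. A minimal sketch, not part of the patch, that reuses flax.linen.make_attention_mask and negates it into the True-means-masked convention described above; the token arrays and pad id are hypothetical:

import jax.numpy as jnp
from flax import linen as nn

PAD_ID = 0                                                    # hypothetical pad token
decoder_tokens = jnp.array([[11, 12, 13, PAD_ID]])            # [batch, dec_len]
encoder_tokens = jnp.array([[7, 8, PAD_ID, PAD_ID, PAD_ID]])  # [batch, enc_len]

# flax produces 1 where attention is allowed; TE expects True where it must be masked out.
encoder_decoder_mask = jnp.logical_not(
    nn.make_attention_mask(decoder_tokens != PAD_ID,
                           encoder_tokens != PAD_ID,
                           dtype=jnp.bool_))                  # [batch, 1, dec_len, enc_len]

With self_attn_mask_type='causal' the attention_mask can stay None; with 'padding' or 'padding_causal' it has to be supplied in the same True-means-masked form.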
decode: bool, default = False diff --git a/transformer_engine/jax/fused_attn.py b/transformer_engine/jax/fused_attn.py index 008f13ef4a..8b32163811 100644 --- a/transformer_engine/jax/fused_attn.py +++ b/transformer_engine/jax/fused_attn.py @@ -14,20 +14,29 @@ from transformer_engine_jax import NVTE_QKV_Layout from .cpp_extensions import FusedAttnHelper -from .cpp_extensions import cross_fused_attn_fwd, cross_fused_attn_bwd -from .cpp_extensions import self_fused_attn_fwd, self_fused_attn_bwd +from .cpp_extensions import fused_attn_fwd_kvpacked, fused_attn_bwd_kvpacked +from .cpp_extensions import fused_attn_fwd_qkvpacked, fused_attn_bwd_qkvpacked from .cpp_extensions import fused_attn_fwd, fused_attn_bwd class AttnBiasType(Enum): - """Attention Bias Type.""" + """ + NO_BIAS: Softmax is performed as softmax(scale * qk) + PRE_SCALE_BIAS: Softmax is performed as softmax(scale * (qk + bias)) + POST_SCALE_BIAS: Softmax is performed as softmax(scale * qk + bias) + """ NO_BIAS = NVTE_Bias_Type.NVTE_NO_BIAS PRE_SCALE_BIAS = NVTE_Bias_Type.NVTE_PRE_SCALE_BIAS POST_SCALE_BIAS = NVTE_Bias_Type.NVTE_POST_SCALE_BIAS class AttnMaskType(Enum): - """Attention Mask Type.""" + """ + NO_MASK: No attention mask is applied. + PADDING_MASK: Indicates the presence of paddings at the end of each sequence. + CAUSAL_MASK: An upper triangular mask is applied to the softmax inputs. + PADDING_CAUSAL_MASK: A combination of both causal and padding masks. + """ NO_MASK = NVTE_Mask_Type.NVTE_NO_MASK PADDING_MASK = NVTE_Mask_Type.NVTE_PADDING_MASK CAUSAL_MASK = NVTE_Mask_Type.NVTE_CAUSAL_MASK @@ -47,99 +56,105 @@ def canonicalize_attn_mask_type(attn_mask_type: str): The overhead between padding and non-padding version should be small. However, we will lease this limitation in the near feature. 
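For reference, the AttnBiasType semantics documented in this hunk can be written out with plain jnp; a minimal sketch, not part of the patch, with a hypothetical helper name, for unnormalized scores qk of shape [..., s_q, s_kv]:

from jax.nn import softmax

def reference_bias_softmax(qk, bias, scale, bias_type):
    # Mirrors the AttnBiasType docstring: where the bias enters relative to the scale.
    if bias_type == 'no_bias':
        return softmax(scale * qk, axis=-1)
    if bias_type == 'pre_scale_bias':
        return softmax(scale * (qk + bias), axis=-1)
    if bias_type == 'post_scale_bias':
        return softmax(scale * qk + bias, axis=-1)
    raise ValueError(f"unknown bias_type: {bias_type}")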
""" - if attn_mask_type in ['causal', 'padding_causal']: - return AttnMaskType.PADDING_CAUSAL_MASK - if attn_mask_type in ['no_mask', 'padding']: - return AttnMaskType.PADDING_MASK - raise ValueError(f"Unsupported {attn_mask_type=}, " - "supported attn_mask_type={'no_mask', 'padding', 'causal', 'padding_causal'}") - - -def is_fused_attn_kernel_available(q_type, kv_type, qkv_layout, attn_bias_type, attn_mask_type, - dropout_probability, num_heads_q, num_heads_kv, max_seqlen_q, - max_seqlen_kv, head_dim): + match attn_mask_type: + case 'no_mask': + return AttnMaskType.NO_MASK + case 'padding': + return AttnMaskType.PADDING_MASK + case 'causal': + return AttnMaskType.CAUSAL_MASK + case 'padding_causal' | 'causal_padding': + return AttnMaskType.PADDING_CAUSAL_MASK + raise ValueError(f"Unsupported {attn_mask_type=}, supported attn_mask_type=" + "{'no_mask', 'padding', 'causal', 'padding_causal', 'causal_padding'}") + + +def is_fused_attn_kernel_available(q_dtype, kv_dtype, qkv_layout, attn_bias_type, attn_mask_type, + dropout_probability, q_num_heads, kv_num_heads, q_max_seqlen, + kv_max_seqlen, head_dim): """ - To check whether the fused attention kernel is available + To check whether the fused attention kernel is supported """ - return FusedAttnHelper(q_type, kv_type, qkv_layout.value, attn_bias_type.value, - attn_mask_type.value, dropout_probability, num_heads_q, num_heads_kv, - max_seqlen_q, max_seqlen_kv, head_dim).is_fused_attn_kernel_available() + return FusedAttnHelper(q_dtype, kv_dtype, qkv_layout.value, attn_bias_type.value, + attn_mask_type.value, dropout_probability, q_num_heads, kv_num_heads, + q_max_seqlen, kv_max_seqlen, head_dim).is_fused_attn_kernel_available() -def self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray, - seed: jnp.ndarray | None, attn_bias_type: AttnBiasType, - attn_mask_type: AttnMaskType, scaling_factor: float, - dropout_probability: float, is_training: bool): +def fused_attn_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray, + seed: jnp.ndarray | None, attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, scaling_factor: float, + dropout_probability: float, is_training: bool): """ - Self fused attention wrapper + Fused attention with the qkvpacked inputs """ - output = _self_fused_attn(qkv, - bias, - mask, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + output = _fused_attn_qkvpacked(qkv, + bias, + mask, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) return output @partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8)) -def _self_fused_attn(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray, - seed: jnp.ndarray | None, attn_bias_type: AttnBiasType, - attn_mask_type: AttnMaskType, scaling_factor: float, - dropout_probability: float, is_training: bool): - - output, _ = _self_fused_attn_fwd_rule(qkv, bias, mask, seed, attn_bias_type, attn_mask_type, - scaling_factor, dropout_probability, is_training) +def _fused_attn_qkvpacked(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray, + seed: jnp.ndarray | None, attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, scaling_factor: float, + dropout_probability: float, is_training: bool): + + output, _ = _fused_attn_fwd_qkvpacked_rule(qkv, bias, mask, seed, attn_bias_type, + 
attn_mask_type, scaling_factor, dropout_probability, + is_training) return output -def _self_fused_attn_fwd_rule(qkv: jnp.ndarray, bias: jnp.ndarray | None, - mask: jnp.ndarray, seed: jnp.ndarray | None, - attn_bias_type: AttnBiasType, - attn_mask_type: AttnMaskType, - scaling_factor: float, dropout_probability: float, - is_training: bool): - if mask is None: +def _fused_attn_fwd_qkvpacked_rule(qkv: jnp.ndarray, bias: jnp.ndarray | None, mask: jnp.ndarray, + seed: jnp.ndarray | None, attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, scaling_factor: float, + dropout_probability: float, is_training: bool): + if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: batch, seqlen, *_ = qkv.shape actual_seqlen = jnp.full((batch,), seqlen, dtype=jnp.int32) else: + assert mask is not None mask = jnp.logical_not(mask) actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,) - output, softmax_aux, rng_state = self_fused_attn_fwd(qkv, - bias, - actual_seqlen, - seed, - attn_bias_type=attn_bias_type.value, - attn_mask_type=attn_mask_type.value, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + output, softmax_aux, rng_state = fused_attn_fwd_qkvpacked( + qkv, + bias, + actual_seqlen, + seed, + attn_bias_type=attn_bias_type.value, + attn_mask_type=attn_mask_type.value, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) output = checkpoint_name(output, 'context') softmax_aux = checkpoint_name(softmax_aux, 'context') rng_state = checkpoint_name(rng_state, 'context') return output, (qkv, bias, softmax_aux, rng_state, output, actual_seqlen) -def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, - is_training, ctx, dz): +def _fused_attn_bwd_qkvpacked_rule(attn_bias_type, attn_mask_type, scaling_factor, + dropout_probability, is_training, ctx, dz): qkv, bias, softmax_aux, rng_state, output, actual_seqlen = ctx - grad_qkv, grad_bias = self_fused_attn_bwd(qkv, - bias, - softmax_aux, - rng_state, - output, - dz, - actual_seqlen, - attn_bias_type=attn_bias_type.value, - attn_mask_type=attn_mask_type.value, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + grad_qkv, grad_bias = fused_attn_bwd_qkvpacked(qkv, + bias, + softmax_aux, + rng_state, + output, + dz, + actual_seqlen, + attn_bias_type=attn_bias_type.value, + attn_mask_type=attn_mask_type.value, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) if attn_bias_type == AttnBiasType.NO_BIAS: grad_bias = None @@ -147,91 +162,96 @@ def _self_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dr return grad_qkv, grad_bias, None, None -_self_fused_attn.defvjp(_self_fused_attn_fwd_rule, _self_fused_attn_bwd_rule) +_fused_attn_qkvpacked.defvjp(_fused_attn_fwd_qkvpacked_rule, _fused_attn_bwd_qkvpacked_rule) -def cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, - seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType, - scaling_factor: float, dropout_probability: float, is_training: bool): +def fused_attn_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, + seed: jnp.ndarray, attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, scaling_factor: float, + dropout_probability: float, is_training: bool): """ - Cross multi-head attention wrapper + Fused 
attention with the kvpacked inputs """ - output = _cross_fused_attn(q, - kv, - bias, - mask, - seed, - attn_bias_type=attn_bias_type, - attn_mask_type=attn_mask_type, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + output = _fused_attn_kvpacked(q, + kv, + bias, + mask, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) return output @partial(jax.custom_vjp, nondiff_argnums=(5, 6, 7, 8, 9)) -def _cross_fused_attn(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, - seed: jnp.ndarray, attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType, - scaling_factor: float, dropout_probability: float, is_training: bool): - - output, _ = _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type, - scaling_factor, dropout_probability, is_training) +def _fused_attn_kvpacked(q: jnp.ndarray, kv: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, + seed: jnp.ndarray, attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, scaling_factor: float, + dropout_probability: float, is_training: bool): + + output, _ = _fused_attn_fwd_kvpacked_rule(q, kv, bias, mask, seed, attn_bias_type, + attn_mask_type, scaling_factor, dropout_probability, + is_training) return output -def _cross_fused_attn_fwd_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type, - scaling_factor, dropout_probability, is_training): - if mask is None: +def _fused_attn_fwd_kvpacked_rule(q, kv, bias, mask, seed, attn_bias_type, attn_mask_type, + scaling_factor, dropout_probability, is_training): + if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: batch, s_q, *_ = q.shape s_kv = kv.shape[1] q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32) kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32) else: + assert mask is not None mask = jnp.logical_not(mask) q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,) - if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]: + if attn_mask_type == AttnMaskType.PADDING_MASK: kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,) else: # When mask is causal, the actual seqlen is not the last row, use max to find it kv_actual_seqlen = jnp.max(jnp.sum(mask, axis=-1, dtype=jnp.int32), axis=(-1, -2)) - output, softmax_aux, rng_state = cross_fused_attn_fwd(q, - kv, - bias, - q_actual_seqlen, - kv_actual_seqlen, - seed, - attn_bias_type=attn_bias_type.value, - attn_mask_type=attn_mask_type.value, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + output, softmax_aux, rng_state = fused_attn_fwd_kvpacked( + q, + kv, + bias, + q_actual_seqlen, + kv_actual_seqlen, + seed, + attn_bias_type=attn_bias_type.value, + attn_mask_type=attn_mask_type.value, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) output = checkpoint_name(output, 'context') softmax_aux = checkpoint_name(softmax_aux, 'context') rng_state = checkpoint_name(rng_state, 'context') return output, (q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen) -def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, - is_training, ctx, dz): +def _fused_attn_bwd_kvpacked_rule(attn_bias_type, attn_mask_type, scaling_factor, + 
dropout_probability, is_training, ctx, dz): q, kv, bias, softmax_aux, rng_state, output, q_actual_seqlen, kv_actual_seqlen = ctx - grad_q, grad_kv, grad_bias = cross_fused_attn_bwd(q, - kv, - bias, - softmax_aux, - rng_state, - output, - dz, - q_actual_seqlen, - kv_actual_seqlen, - attn_bias_type=attn_bias_type.value, - attn_mask_type=attn_mask_type.value, - scaling_factor=scaling_factor, - dropout_probability=dropout_probability, - is_training=is_training) + grad_q, grad_kv, grad_bias = fused_attn_bwd_kvpacked(q, + kv, + bias, + softmax_aux, + rng_state, + output, + dz, + q_actual_seqlen, + kv_actual_seqlen, + attn_bias_type=attn_bias_type.value, + attn_mask_type=attn_mask_type.value, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training) if attn_bias_type == AttnBiasType.NO_BIAS: grad_bias = None @@ -239,7 +259,7 @@ def _cross_fused_attn_bwd_rule(attn_bias_type, attn_mask_type, scaling_factor, d return grad_q, grad_kv, grad_bias, None, None -_cross_fused_attn.defvjp(_cross_fused_attn_fwd_rule, _cross_fused_attn_bwd_rule) +_fused_attn_kvpacked.defvjp(_fused_attn_fwd_kvpacked_rule, _fused_attn_bwd_kvpacked_rule) def fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarray, mask: jnp.ndarray, @@ -277,15 +297,16 @@ def _fused_attn(q: jnp.ndarray, k: jnp.ndarray, v: jnp.ndarray, bias: jnp.ndarra def _fused_attn_fwd_rule(q, k, v, bias, mask, seed, attn_bias_type, attn_mask_type, scaling_factor, dropout_probability, is_training): - if mask is None: + if attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: batch, s_q, *_ = q.shape s_kv = k.shape[1] q_actual_seqlen = jnp.full((batch,), s_q, dtype=jnp.int32) kv_actual_seqlen = jnp.full((batch,), s_kv, dtype=jnp.int32) else: + assert mask is not None mask = jnp.logical_not(mask) q_actual_seqlen = jnp.sum(mask, axis=-2, dtype=jnp.int32)[..., 0, 0] # shape = (b,) - if attn_mask_type not in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]: + if attn_mask_type == AttnMaskType.PADDING_MASK: kv_actual_seqlen = jnp.sum(mask, axis=-1, dtype=jnp.int32)[..., 0, 0] # shape = (b,) else: # When mask is causal, the actual seqlen is not the last row, use max to find it diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py index 3711d9898f..57c7e75e9e 100644 --- a/transformer_engine/pytorch/module/layernorm_linear.py +++ b/transformer_engine/pytorch/module/layernorm_linear.py @@ -107,13 +107,18 @@ def forward( if ub_overlap_ag: tp_world_size = get_distributed_world_size(tp_group) - if tp_world_size == 1 or (not is_grad_enabled) or return_layernorm_output: + if tp_world_size == 1 or (not is_grad_enabled): ub_overlap_ag = False if ub_overlap_ag: dim_size = list(inputmat.size()) dim_size[0] = dim_size[0] * tp_world_size ub_obj_lnout = get_ub(ub_name+"_fprop") - ln_out = ub_obj_lnout.get_ubuf_output(0) + if return_layernorm_output: + # First prepare LN output in higher precision, + # which will be later copied to a FP8 UB + ln_out = torch.empty_like(inputmat) + else: + ln_out = ub_obj_lnout.get_ubuf_output(0) else: ln_out_dtype = torch.uint8 if (fp8 and not return_layernorm_output) else inputmat.dtype ln_out = torch.empty_like(inputmat, dtype=ln_out_dtype) @@ -136,7 +141,8 @@ def forward( ln_out_gathered = False if ub_overlap_ag: ln_out_total = ub_obj_lnout.get_ubuf_output(1) - ln_out = torch.empty_like(ln_out) + if not return_layernorm_output: + ln_out = torch.empty_like(ln_out) if 
ub_obj_lnout.is_atomic_gemm(): ub_algo = tex.UbufOverlapAlgo.ATOMIC_GEMM_AG_P2P else: @@ -153,12 +159,22 @@ def forward( if return_layernorm_output: ln_out_return = ln_out_total if return_layernorm_output_gathered else ln_out if fp8: - ln_out = tex.cast_to_fp8( - ln_out, - fp8_meta["scaling_fwd"], - tex.FP8FwdTensors.GEMM1_INPUT, - fp8_dtype_forward, - ) + if ub_overlap_ag: + ln_out_fp8 = ub_obj_lnout.get_ubuf_output(0) + tex.cast_to_fp8( + ln_out, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype_forward, + out=ln_out_fp8) + ln_out = ln_out_fp8 + else: + ln_out = tex.cast_to_fp8( + ln_out, + fp8_meta["scaling_fwd"], + tex.FP8FwdTensors.GEMM1_INPUT, + fp8_dtype_forward, + ) if fp8: bias_dtype = (