diff --git a/tests/jax/conftest.py b/tests/jax/conftest.py
index ccb6690a87..5bb86c6081 100644
--- a/tests/jax/conftest.py
+++ b/tests/jax/conftest.py
@@ -20,7 +20,7 @@ def clear_live_arrays():
 
 
 @pytest.fixture(autouse=True, scope="module")
-def enable_fused_attn():
+def enable_fused_attn_after_hopper():
     """
     Enable fused attn for hopper+ arch.
     Fused attn kernels on pre-hopper arch are not deterministic.
diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py
index e194a228d2..1538062975 100644
--- a/tests/jax/test_distributed_fused_attn.py
+++ b/tests/jax/test_distributed_fused_attn.py
@@ -20,7 +20,6 @@ from utils import (
     make_causal_mask,
     make_self_mask,
-    assert_tree_like_allclose,
     assert_allclose,
     print_debug_tensor_stats,
 )
@@ -32,7 +31,6 @@
     AttnMaskType,
     QKVLayout,
     QKVFormat,
-    get_qkv_format,
     reorder_causal_load_balancing,
     inverse_reorder_causal_load_balancing,
     CPStrategy,
@@ -421,7 +419,7 @@ def impl_test_contex_parallel_attn(
     dropout_prob = 0.0
     is_training = True
     dp_size, cp_size, tp_size = mesh_shape
-    qkv_format = get_qkv_format(qkv_layout)
+    qkv_format = qkv_layout.get_qkv_format()
 
     batch, seqlen, num_head, hidden = data_shape
 
@@ -503,7 +501,7 @@ def grad_func(func, *args, **kwargs):
         # Gradient is small, use a gradient multiplier to amplify the gradient
         _, max_seq_len, num_heads, _ = data_shape
         gradient_multiplier = max_seq_len * num_heads
-        if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK]:
+        if attn_mask_type.is_causal():
             gradient_multiplier /= 10
         ret_valid = func(*args, **kwargs)
         return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(dtype)
diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py
index af05538ef5..759ea893ef 100644
--- a/tests/jax/test_fused_attn.py
+++ b/tests/jax/test_fused_attn.py
@@ -28,7 +28,6 @@
     QKVFormat,
     fused_attn,
     fused_attn_thd,
-    get_qkv_format,
     make_swa_mask,
 )
 from transformer_engine.jax.cpp_extensions import FusedAttnHelper
@@ -50,6 +49,7 @@ def init():
     yield
 
 
+@partial(jax.jit, static_argnums=(5, 6, 7, 9))
 def general_dot_product_attention(
     query: ArrayLike,
     key: ArrayLike,
@@ -102,29 +102,36 @@ def general_dot_product_attention(
     return context
 
 
-def is_causal_mask(mask: AttnMaskType):
-    """
-    Check if the mask is a causal mask
-    """
-    return mask in [AttnMaskType.CAUSAL_MASK, AttnMaskType.PADDING_CAUSAL_MASK]
-
-
-def make_causal_mask(q_tokens: ArrayLike, kv_tokens: ArrayLike) -> Array:
+@jax.jit
+def make_causal_mask(
+    segment_ids_q: ArrayLike,
+    segment_ids_kv: ArrayLike,
+    segment_pos_q: ArrayLike = None,
+    segment_pos_kv: ArrayLike = None,
+) -> Array:
     """
     Create inverse padded causal mask where `True` means allowing the corresponding
     position to participate in attention and `False` means masking out that position.
+    If segment_pos is not provided, an arange of the segment_ids will be applied.
""" - q_idxs = jnp.broadcast_to(jnp.arange(q_tokens.shape[-1], dtype=jnp.int32), q_tokens.shape) - kv_idxs = jnp.broadcast_to(jnp.arange(kv_tokens.shape[-1], dtype=jnp.int32), kv_tokens.shape) - inv_causal_mask = make_attention_mask(q_idxs, kv_idxs, jnp.greater_equal) + if segment_pos_q is None: + segment_pos_q = jnp.broadcast_to( + jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape + ) + if segment_pos_kv is None: + segment_pos_kv = jnp.broadcast_to( + jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape + ) + inv_causal_mask = make_attention_mask(segment_pos_q, segment_pos_kv, jnp.greater_equal) return inv_causal_mask +@partial(jax.jit, static_argnums=(4, 5)) def make_mask( - q_token: ArrayLike, - kv_token: ArrayLike, - segment_pad_q: ArrayLike, - segment_pad_kv: ArrayLike, + segment_ids_q: ArrayLike, + segment_ids_kv: ArrayLike, + segment_pos_q: ArrayLike, + segment_pos_kv: ArrayLike, attn_mask_type: AttnMaskType, window_size: Optional[Tuple[int, int]] = None, ) -> Array: @@ -132,18 +139,31 @@ def make_mask( Create attention mask based on mask type. A `True` value in the mask means masking out the corresponding position and a `False` value means allowing that position to participate in attention. + + - segment_ids should start with 1, and using 0s for the paddings. + Expected that each segment starts without paddings. + - segment_pos marks the token position in the segments. + + A example pair of segments_ids and segment_pos: + segment_ids: [1, 1, 1, 0, 2, 2, 2, 3, 3, 3, 4, 0, 0, 5, 5, 5] + segment_pos: [0, 1, 2, 3, 0, 1, 2, 0, 1, 2, 0, 1, 2, 0, 1, 2] """ inv_mask = make_attention_mask( - q_token, kv_token, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0)) + segment_ids_q, segment_ids_kv, lambda x, y: (jnp.logical_and(jnp.equal(x, y), x != 0)) ) - if is_causal_mask(attn_mask_type): - inv_causal_mask = make_causal_mask(q_token, kv_token) - inv_mask = combine_masks(inv_causal_mask, inv_mask) - if segment_pad_q is not None and segment_pad_kv is not None: - inv_pad_mask = make_attention_mask( - segment_pad_q, segment_pad_kv, lambda x, y: jnp.logical_and(x != 1, y != 1) + if attn_mask_type.is_causal(): + if segment_pos_q is None: + segment_pos_q = jnp.broadcast_to( + jnp.arange(segment_ids_q.shape[-1], dtype=jnp.int32), segment_ids_q.shape + ) + if segment_pos_kv is None: + segment_pos_kv = jnp.broadcast_to( + jnp.arange(segment_ids_kv.shape[-1], dtype=jnp.int32), segment_ids_kv.shape + ) + inv_causal_mask = make_attention_mask( + segment_pos_q, segment_pos_kv, lambda x, y: jnp.greater_equal(x, y) ) - inv_mask = combine_masks(inv_pad_mask, inv_mask) + inv_mask = combine_masks(inv_causal_mask, inv_mask) if window_size is not None: max_seqlen_q = inv_mask.shape[-2] @@ -157,7 +177,8 @@ def make_mask( return mask -def get_seqlens_and_offsets(segment_ids, segment_pad): +@jax.jit +def get_seqlens_and_offsets(segment_ids): batch, max_seqlen = segment_ids.shape bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_seqlen)) seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32)) @@ -165,7 +186,7 @@ def get_seqlens_and_offsets(segment_ids, segment_pad): def _find_offsets(x): same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0) - first_column = jnp.ones((x.shape[0], 1), dtype=bool) + first_column = x[..., :1] != 0 same_as_previous = jnp.hstack((first_column, same_as_previous)) return jax.vmap(partial(jnp.argwhere, size=x.shape[1], fill_value=-1))( same_as_previous @@ -173,13 +194,9 @@ def _find_offsets(x): 
offsets = _find_offsets(segment_ids) offsets = jnp.insert(offsets, -1, values=-1, axis=-1) - if segment_pad is not None: - segment_id_with_paddings = jnp.where(segment_pad, 0, segment_ids) - padding_aware_seqlen = bincount_vmap(segment_id_with_paddings) - output = jnp.insert(padding_aware_seqlen[..., 1:], -1, values=0, axis=-1) - else: - output = jnp.insert(seqlens, -1, values=0, axis=-1) - return output, offsets + seqlens = jnp.insert(seqlens, -1, values=0, axis=-1) + seqlens = jnp.where(seqlens, seqlens, -1) + return seqlens, offsets @jax.jit @@ -200,8 +217,8 @@ def jax_dpa(query, key, value, bias, mask, dropout_rng, **kwargs): query, key, value, - bias=bias, - mask=mask, + bias, + mask, deterministic=not kwargs["is_training"], scale_factor=kwargs["scaling_factor"], dropout_rate=kwargs["dropout_probability"], @@ -228,7 +245,6 @@ def customcall_fused_dpa( TE customcall dot product attention implementation """ qkv_layout = kwargs["qkv_layout"] - is_thd = get_qkv_format(qkv_layout) == QKVFormat.THD match qkv_layout: case QKVLayout.BS3HD | QKVLayout.T3HD: query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value]) @@ -242,7 +258,7 @@ def customcall_fused_dpa( qkv_args = (query, key, value) case _: raise ValueError(f"Unsupported {qkv_layout=}") - if not is_thd: + if not qkv_layout.is_thd(): kwargs.pop("max_segments_per_seq") return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype) return fused_attn_thd( @@ -262,10 +278,10 @@ class BiasShape(Enum): Enum class to represent the different bias shapes used in the fused attention. """ - BIAS_1HSS = "1HSS" - BIAS_B1SS = "B1SS" - BIAS_BHSS = "BHSS" - BIAS_11SS = "11SS" + _1HSS = "1HSS" + _B1SS = "B1SS" + _BHSS = "BHSS" + _11SS = "11SS" @dataclass @@ -300,18 +316,12 @@ def _get_max_segments_per_sequence(self): def _check_configs(self): # TODO(rewang): probably adds this in is_fused_attn_available - if get_qkv_format(self.qkv_layout) == QKVFormat.THD and not self.attn_mask_type in [ - AttnMaskType.PADDING_MASK, - AttnMaskType.PADDING_CAUSAL_MASK, - ]: + if self.qkv_layout.is_thd() and not self.attn_mask_type.is_padding(): pytest.skip("THD format requires padding masks.") - qkv_format = get_qkv_format(self.qkv_layout) - if self.qkv_layout == QKVLayout.BS3HD or qkv_format == QKVFormat.THD: + if self.qkv_layout.is_qkvpacked(): if self.max_seqlen_q != self.max_seqlen_kv: pytest.skip(f"{self.qkv_layout} requires max_seqlen_q == max_seqlen_kv") - - if self.qkv_layout == QKVLayout.BS3HD or self.qkv_layout == QKVLayout.T3HD: if self.num_heads_q != self.num_heads_kv: pytest.skip(f"{self.qkv_layout} requires num_heads_q == num_heads_kv") @@ -339,15 +349,11 @@ def _check_configs(self): if ( self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS - and self.bias_shape != BiasShape.BIAS_1HSS + and self.bias_shape != BiasShape._1HSS ): - if self.attn_mask_type not in [ - AttnMaskType.NO_MASK, - AttnMaskType.CAUSAL_MASK, - ]: + if self.attn_mask_type.is_padding(): pytest.skip( - "B1SS, BHSS and 11SS bias shapes are only supported for " - "AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK." 
+ "B1SS, BHSS and 11SS bias shapes are only supported for non-padding mask" ) elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen: pytest.skip( @@ -370,18 +376,18 @@ def _setup_inputs(self): if self.attn_bias_type == AttnBiasType.NO_BIAS: bias_shape = None - elif self.bias_shape == BiasShape.BIAS_1HSS: + elif self.bias_shape == BiasShape._1HSS: bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv) - elif self.bias_shape == BiasShape.BIAS_B1SS: + elif self.bias_shape == BiasShape._B1SS: bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv) - elif self.bias_shape == BiasShape.BIAS_BHSS: + elif self.bias_shape == BiasShape._BHSS: bias_shape = ( self.batch_size, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv, ) - elif self.bias_shape == BiasShape.BIAS_11SS: + elif self.bias_shape == BiasShape._11SS: bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv) else: pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!") @@ -391,7 +397,7 @@ def _setup_inputs(self): self.v = jax.random.uniform(v_key, v_shape, self.dtype, -1.0) if self.attn_bias_type != AttnBiasType.NO_BIAS: - if self.bias_shape == BiasShape.BIAS_1HSS: + if self.bias_shape == BiasShape._1HSS: self.bias = jax.random.uniform(bias_key, bias_shape, self.dtype, -1.0) else: # [b, 1, s, s], [b, h, s, s] and [1, 1, s, s] bias shapes are workarounds for @@ -408,10 +414,10 @@ def _setup_inputs(self): else: self.bias = None - if self.attn_mask_type in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]: - pad_ratio = 0.0 - else: + if self.attn_mask_type.is_padding(): pad_ratio = 0.3 + else: + pad_ratio = 0.0 def gen_valid(bs, max_seqlen, pad_ratio): pad_len = int(max_seqlen * pad_ratio) @@ -425,6 +431,8 @@ def generate_random_segment_ids( rng = np.random.default_rng(seed=seed) # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad segment_ids = np.zeros((batch_size, sequence_length), dtype=int) + segment_pos = np.zeros((batch_size, sequence_length), dtype=int) + # [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0] # [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad segment_pad = np.zeros((batch_size, sequence_length), dtype=int) @@ -440,58 +448,62 @@ def generate_random_segment_ids( break segment_end = current_pos + segment_size segment_ids[i, current_pos:segment_end] = segment_id + segment_pos[i, current_pos:segment_end] = np.arange(segment_size) if with_segment_pad: num_valid = rng.integers(1, segment_size + 1) segment_pad[i, current_pos + num_valid : segment_end] = 1 current_pos = segment_end segment_id += 1 segment_pad[i, current_pos:sequence_length] = 1 - return segment_ids, segment_pad - if get_qkv_format(self.qkv_layout) == QKVFormat.THD: + segment_ids, segment_pos, segment_pad = map( + jnp.asarray, [segment_ids, segment_pos, segment_pad] + ) + segment_ids = jnp.where(segment_pad, 0, segment_ids) + return segment_ids, segment_pos, segment_pad + + if self.qkv_layout.is_thd(): self.num_segments_per_seq = 2 - self.token_q, self.segment_pad_q = generate_random_segment_ids( + self.segment_ids_q, self.segment_pos_q, self.pad_q = generate_random_segment_ids( self.batch_size, self.max_seqlen_q, self.num_segments_per_seq, seed=42 ) - # TODO(rewang): Check if qkvpacked supported different q/kv - # TODO(rewang): Causal with different q/kv segment_id fails - if self.qkv_layout == QKVLayout.T3HD or is_causal_mask(self.attn_mask_type): - self.token_kv = self.token_q - self.segment_pad_kv = self.segment_pad_q + if self.qkv_layout == QKVLayout.T3HD: + self.segment_ids_kv = 
self.segment_ids_q + self.segment_pos_kv = self.segment_pos_q + self.pad_kv = self.pad_q else: - self.token_kv, self.segment_pad_kv = generate_random_segment_ids( + self.segment_ids_kv, self.segment_pos_kv, self.pad_kv = generate_random_segment_ids( self.batch_size, self.max_seqlen_kv, self.num_segments_per_seq, seed=2024, ) - self.pad_q = self.segment_pad_q - self.pad_kv = self.segment_pad_kv + self.seqlens_q, self.offsets_q = get_seqlens_and_offsets(self.segment_ids_q) + self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets(self.segment_ids_kv) else: self.num_segments_per_seq = 1 - self.token_q, self.pad_q = gen_valid(self.batch_size, self.max_seqlen_q, pad_ratio) - self.token_kv, self.pad_kv = gen_valid(self.batch_size, self.max_seqlen_kv, pad_ratio) - self.segment_pad_q = self.segment_pad_kv = None + self.segment_ids_q, self.pad_q = gen_valid( + self.batch_size, self.max_seqlen_q, pad_ratio + ) + self.segment_ids_kv, self.pad_kv = gen_valid( + self.batch_size, self.max_seqlen_kv, pad_ratio + ) + self.segment_pos_q = self.segment_pos_kv = None + self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None + # For reference code self.mask = make_mask( - self.token_q, - self.token_kv, - self.segment_pad_q, - self.segment_pad_kv, + self.segment_ids_q, + self.segment_ids_kv, + self.segment_pos_q, + self.segment_pos_kv, self.attn_mask_type, self.window_size, ) - if get_qkv_format(self.qkv_layout) == QKVFormat.THD: - self.seqlens_q, self.offsets_q = get_seqlens_and_offsets( - self.token_q, self.segment_pad_q - ) - self.seqlens_kv, self.offsets_kv = get_seqlens_and_offsets( - self.token_kv, self.segment_pad_kv - ) + if self.qkv_layout.is_thd(): self.mask_for_customcall = None # THD format doesn't support mask else: - self.seqlens_q = self.seqlens_kv = self.offsets_q = self.offsets_kv = None self.mask_for_customcall = self.mask self.dropout_rng = dropout_key if self.dropout_prob > 0 else None @@ -547,13 +559,11 @@ def test_backward(self): """ self._setup_inputs() - if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape != BiasShape.BIAS_1HSS: - pytest.skip("Bias gradient calculation is only supported for 1HSS bias shape.") def grad_func(func, *args, **kwargs): # Gradient is small, use a gradient multiplier to amplify the gradient gradient_multiplier = self.max_seqlen_q * self.num_heads_q - if is_causal_mask(self.attn_mask_type): + if self.attn_mask_type.is_causal(): gradient_multiplier /= 10 # Keep only valid result for the gradient ret_valid = jnp.where( @@ -586,7 +596,7 @@ def grad_func(func, *args, **kwargs): } # We can compute dBias only for the [1, h, s, s] layout - arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape.BIAS_1HSS else (0, 1, 2) + arg_nums = (0, 1, 2, 3) if self.bias_shape == BiasShape._1HSS else (0, 1, 2) # Use FP16/BF16 to sum the results may cause overflow, use FP32 for the summation jitted_primitive = jit( @@ -629,7 +639,7 @@ def check_dqkv(primitive, reference, pad): check_dqkv(primitive_dk, reference_dk, self.pad_kv) check_dqkv(primitive_dv, reference_dv, self.pad_kv) - if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape.BIAS_1HSS: + if self.attn_bias_type != AttnBiasType.NO_BIAS and self.bias_shape == BiasShape._1HSS: primitive_dbias = primitive_dgrad[3] reference_dbias = reference_dgrad[3] @@ -658,16 +668,6 @@ def check_dqkv(primitive, reference, pad): ) -@pytest.mark.parametrize( - "attn_bias_type, bias_shape", - [ - pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), - 
pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id="POST_SCALE_BIAS-1HSS"), - pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id="POST_SCALE_BIAS-B1SS"), - pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id="POST_SCALE_BIAS-BHSS"), - pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id="POST_SCALE_BIAS-11SS"), - ], -) @pytest.mark.parametrize( "attn_mask_type", [ @@ -736,6 +736,16 @@ class TestFusedAttn: pytest.param(False, id="INFERENCE"), ], ) + @pytest.mark.parametrize( + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._B1SS, id="POST_SCALE_BIAS-B1SS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._BHSS, id="POST_SCALE_BIAS-BHSS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._11SS, id="POST_SCALE_BIAS-11SS"), + ], + ) def _test_forward( b, s_q, @@ -779,6 +789,13 @@ def _test_forward( runner.test_forward() @staticmethod + @pytest.mark.parametrize( + "attn_bias_type, bias_shape", + [ + pytest.param(AttnBiasType.NO_BIAS, None, id="NO_BIAS"), + pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape._1HSS, id="POST_SCALE_BIAS-1HSS"), + ], + ) def test_backward( b, s_q, diff --git a/tests/jax/utils.py b/tests/jax/utils.py index 78a6225e1f..242bafa5e2 100644 --- a/tests/jax/utils.py +++ b/tests/jax/utils.py @@ -19,7 +19,11 @@ from jax import nn as jax_nn from jax import random as jax_random -from transformer_engine.jax.attention import AttnMaskType, make_swa_mask +from transformer_engine.jax.attention import ( + AttnMaskType, + canonicalize_attn_mask_type, + make_swa_mask, +) from transformer_engine.jax.fp8 import DType as TEDType PRNGKey = Any @@ -913,15 +917,7 @@ def apply_swa_mask( window_size: Tuple[int, int] = (-1, -1), ) -> Array: """Apply the sliding window mask to a given mask""" - mask_map = { - "no_mask": AttnMaskType.NO_MASK, - "padding": AttnMaskType.PADDING_MASK, - "causal": AttnMaskType.CAUSAL_MASK, - "padding_causal": AttnMaskType.PADDING_CAUSAL_MASK, - "causal_bottom_right": AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, - "padding_causal_bottom_right": AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, - } - _attn_mask_type = mask_map.get(attn_mask_type, None) + _attn_mask_type = canonicalize_attn_mask_type(attn_mask_type) assert _attn_mask_type is not None max_seqlen_q = original_mask.shape[-2] max_seqlen_kv = original_mask.shape[-1] diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 3ecc9bcd75..53451b6a78 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -46,6 +46,42 @@ class AttnMaskType(Enum): CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_CAUSAL_BOTTOM_RIGHT_MASK PADDING_CAUSAL_BOTTOM_RIGHT_MASK = NVTE_Mask_Type.NVTE_PADDING_CAUSAL_BOTTOM_RIGHT_MASK + def is_causal(self): + """Returns True if the mask is a causal mask""" + return self in [ + AttnMaskType.CAUSAL_MASK, + AttnMaskType.PADDING_CAUSAL_MASK, + AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, + ] + + def is_padding(self): + """Returns True if the mask includes padding""" + return self in [ + AttnMaskType.PADDING_MASK, + AttnMaskType.PADDING_CAUSAL_MASK, + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, + ] + + def is_bottom_right(self): + """Returns True if the causal mask is calculated from the bottom-right section""" + return self in [ 
+ AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, + AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, + ] + + +class QKVFormat(Enum): + """ + SBHD: q,k,v memory layout with [s, b, ..., h, d] + BSHD: q,k,v memory layout with [b, s, ..., h, d] + THD: q,k,v memory layout is same as BSHD, but allow multiple segments packed in a sequence. + """ + + SBHD = NVTE_QKV_Format.NVTE_SBHD + BSHD = NVTE_QKV_Format.NVTE_BSHD + THD = NVTE_QKV_Format.NVTE_THD + class QKVLayout(Enum): """ @@ -66,17 +102,35 @@ class QKVLayout(Enum): THD_T2HD = NVTE_QKV_Layout.NVTE_THD_T2HD THD_THD_THD = NVTE_QKV_Layout.NVTE_THD_THD_THD - -class QKVFormat(Enum): - """ - SBHD: q,k,v memory layout with [s, b, ..., h, d] - BSHD: q,k,v memory layout with [b, s, ..., h, d] - THD: q,k,v memory layout is same as BSHD, but allow multiple segments packed in a sequence. - """ - - SBHD = NVTE_QKV_Format.NVTE_SBHD - BSHD = NVTE_QKV_Format.NVTE_BSHD - THD = NVTE_QKV_Format.NVTE_THD + def get_qkv_format(self): + """ + Return the corresponding qkv_format (BSHD, SBHD, THD) + """ + return QKVFormat(nvte_get_qkv_format(self.value)) + + def is_qkvpacked(self): + """ + Return True if the query, key, value is packed + """ + return self in [QKVLayout.BS3HD, QKVLayout.T3HD] + + def is_kvpacked(self): + """ + Return True if the key, value is packed + """ + return self in [QKVLayout.BSHD_BS2HD, QKVLayout.THD_T2HD] + + def is_separate(self): + """ + Return True if the query, key, value are three separate tensors + """ + return self in [QKVLayout.BSHD_BSHD_BSHD, QKVLayout.THD_THD_THD] + + def is_thd(self): + """ + Return True if the layout belongs to THD + """ + return self in [QKVLayout.T3HD, QKVLayout.THD_T2HD, QKVLayout.THD_THD_THD] class CPStrategy(Enum): @@ -92,13 +146,6 @@ class CPStrategy(Enum): RING = 2 -def get_qkv_format(qkv_layout): - """ - Get qkv_format from qkv_layout - """ - return QKVFormat(nvte_get_qkv_format(qkv_layout.value)) - - def make_swa_mask( max_seqlen_q: int, max_seqlen_kv: int, @@ -136,12 +183,8 @@ def make_swa_mask( swa_mask = jnp.ones((max_seqlen_q, max_seqlen_kv), dtype=dtype) if window_size is None: return swa_mask - bottom_right_masks = [ - AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, - AttnMaskType.PADDING_CAUSAL_BOTTOM_RIGHT_MASK, - ] left_window, right_window = window_size - if attn_mask_type in bottom_right_masks: + if attn_mask_type.is_bottom_right(): if left_window < 0: left_window = max_seqlen_kv if right_window < 0: @@ -310,7 +353,7 @@ def fused_attn( (jnp.ndarray): The output tensor from the fused attention. """ assert ( - get_qkv_format(qkv_layout) != QKVFormat.THD + not qkv_layout.is_thd() ), "Please use transformer_engine.jax.attention.fused_attn_thd for THD format." # Check inputs qkv @@ -327,11 +370,7 @@ def fused_attn( ), f"qkv=(query, key, value) is expected with {qkv_layout=} but got {qkv=}" # convert the mask to seqlens, mask doesn't support ragged offsets - if attn_mask_type in [ - AttnMaskType.NO_MASK, - AttnMaskType.CAUSAL_MASK, - AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK, - ]: + if not attn_mask_type.is_padding(): batch, q_max_seqlen, kv_max_seqlen = _obtain_batch_and_max_seqlen(qkv, qkv_layout) q_seq_lens = jnp.full((batch,), q_max_seqlen, dtype=jnp.int32) kv_seq_lens = jnp.full((batch,), kv_max_seqlen, dtype=jnp.int32) @@ -448,7 +487,7 @@ def fused_attn_thd( QKVLayout.T3HD, 0.125, 0, True, 3) """ assert ( - get_qkv_format(qkv_layout) == QKVFormat.THD + qkv_layout.is_thd() ), "Please use transformer_engine.jax.attention.fused_attn for non-THD format." 
# Check inputs qkv diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 6591861057..f3dfca21ef 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -3,7 +3,7 @@ # See LICENSE for license information. """JAX/TE custom ops for attention""" from dataclasses import dataclass -from functools import partial, reduce, cache +from functools import partial, reduce import operator import os from typing import Optional, Tuple @@ -133,7 +133,6 @@ def get_fused_attn_backend(self): ) @staticmethod - @cache def is_non_deterministic_allowed(): """Check if non-deterministic kernels are allowed""" return bool(int(os.getenv("NVTE_ALLOW_NONDETERMINISTIC_ALGO", "1")))
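A minimal usage sketch of the enum helpers this diff adds to AttnMaskType and QKVLayout in transformer_engine/jax/attention.py; the concrete layout and mask values below are illustrative only, not taken from the tests:

```python
from transformer_engine.jax.attention import AttnMaskType, QKVFormat, QKVLayout

qkv_layout = QKVLayout.THD_THD_THD
attn_mask_type = AttnMaskType.PADDING_CAUSAL_MASK

# Previously callers imported the free function get_qkv_format(qkv_layout);
# the query now lives on the layout enum itself.
assert qkv_layout.get_qkv_format() == QKVFormat.THD
assert qkv_layout.is_thd()
assert qkv_layout.is_separate()      # q, k, v are three separate tensors
assert not qkv_layout.is_qkvpacked()

# Open-coded `attn_mask_type in [...]` checks collapse into predicates.
assert attn_mask_type.is_causal()
assert attn_mask_type.is_padding()
assert not attn_mask_type.is_bottom_right()
```

These predicates replace the membership lists that were previously repeated in the tests and in fused_attn / make_swa_mask.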
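The segment_ids / segment_pos convention documented in make_mask can also be checked standalone. The sketch below mirrors the reference mask construction with plain jnp broadcasting instead of the make_attention_mask / combine_masks helpers the tests use, on a shorter example than the one in the docstring:

```python
import jax.numpy as jnp

# Two packed segments with one padding token (segment_id 0) in between.
segment_ids = jnp.array([1, 1, 0, 2, 2])
segment_pos = jnp.array([0, 1, 2, 0, 1])

# True means "may attend": same non-zero segment, and causal within the segment,
# since segment_pos restarts at 0 for every new segment.
same_segment = (segment_ids[:, None] == segment_ids[None, :]) & (segment_ids[:, None] != 0)
causal = segment_pos[:, None] >= segment_pos[None, :]
allowed = same_segment & causal

# The padding token (row/column 2) attends to nothing and is attended by nothing;
# tokens of segment 2 attend only to earlier tokens of the same segment.
```

This corresponds to the inverse mask (inv_mask) built inside make_mask, before it is flipped so that True means the position is masked out.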