diff --git a/tests/jax/test_distributed_fused_attn.py b/tests/jax/test_distributed_fused_attn.py index 2e15dd4d5d..d7e015dbf7 100644 --- a/tests/jax/test_distributed_fused_attn.py +++ b/tests/jax/test_distributed_fused_attn.py @@ -2,31 +2,18 @@ # # See LICENSE for license information. -import pytest -from functools import partial - import jax import jax.numpy as jnp import numpy as np -from flax.linen import dot_product_attention from jax import random -from jax.sharding import Mesh, NamedSharding, PartitionSpec from distributed_test_base import ( generate_configs, generate_context_parallel_configs, generate_collectives_count, - compare_ops, -) -from utils import ( - make_causal_mask, - make_self_mask, - assert_allclose, - print_debug_tensor_stats, ) from transformer_engine.jax import fp8_autocast from transformer_engine.jax.attention import ( is_fused_attn_kernel_available, - fused_attn, AttnBiasType, AttnMaskType, QKVLayout, @@ -36,10 +23,11 @@ CPStrategy, ) from transformer_engine.jax.sharding import MeshResource +import pytest -from test_fused_attn import FusedAttnRunner, BiasShape, general_dot_product_attention, make_mask +from test_fused_attn import FusedAttnRunner, BiasShape, SeqDescFormat -DTYPES = [jnp.float16, jnp.bfloat16] +DTYPES = [jnp.bfloat16] class TestDistributedSelfAttn: @@ -141,6 +129,7 @@ def test_self_attn( QKVLayout.BS3HD, bias_shape, None, + SeqDescFormat.Seqlens, number_of_devices=device_count, mesh_shape=mesh_shape, mesh_axes=mesh_axes, @@ -205,6 +194,7 @@ def test_cross_attn( QKVLayout.BSHD_BS2HD, bias_shape, None, + SeqDescFormat.Seqlens, number_of_devices=device_count, mesh_shape=mesh_shape, mesh_axes=mesh_axes, @@ -293,6 +283,7 @@ def impl_test_context_parallel_attn( qkv_layout, bias_shape, None, + SeqDescFormat.Seqlens, number_of_devices=device_count, mesh_shape=mesh_shape, mesh_axes=mesh_axes, diff --git a/tests/jax/test_fused_attn.py b/tests/jax/test_fused_attn.py index 710ae1946d..beaf18cea3 100644 --- a/tests/jax/test_fused_attn.py +++ b/tests/jax/test_fused_attn.py @@ -2,7 +2,7 @@ # # See LICENSE for license information. 
"""Tests for fused attention""" -from enum import Enum +from enum import Enum, auto from dataclasses import dataclass, field from functools import partial from math import sqrt @@ -28,12 +28,11 @@ AttnBiasType, AttnMaskType, QKVLayout, - QKVFormat, reorder_causal_load_balancing, inverse_reorder_causal_load_balancing, fused_attn, - fused_attn_thd, make_swa_mask, + SequenceDescriptor, CPStrategy, ) from transformer_engine.jax.cpp_extensions import FusedAttnHelper @@ -199,8 +198,8 @@ def _find_offsets(x): ).squeeze(-1) offsets = _find_offsets(segment_ids) - offsets = jnp.insert(offsets, -1, values=-1, axis=-1) - seqlens = jnp.insert(seqlens, -1, values=0, axis=-1) + offsets = jnp.insert(offsets, offsets.shape[-1], values=-1, axis=-1) + seqlens = jnp.insert(seqlens, seqlens.shape[-1], values=0, axis=-1) seqlens = jnp.where(seqlens, seqlens, -1) return seqlens, offsets @@ -239,11 +238,7 @@ def customcall_fused_dpa( key, value, bias, - mask, - seqlens_q, - seqlens_kv, - offsets_q, - offsets_kv, + sequence_descriptor, dropout_rng, **kwargs, ): @@ -264,19 +259,9 @@ def customcall_fused_dpa( qkv_args = (query, key, value) case _: raise ValueError(f"Unsupported {qkv_layout=}") - if not qkv_layout.is_thd(): - kwargs.pop("max_segments_per_seq") - return fused_attn(qkv_args, bias, mask, dropout_rng, **kwargs).astype(query.dtype) - return fused_attn_thd( - qkv_args, - bias, - seqlens_q, - seqlens_kv, - offsets_q, - offsets_kv, - dropout_rng, - **kwargs, - ).astype(query.dtype) + return fused_attn(qkv_args, bias, sequence_descriptor, dropout_rng, **kwargs).astype( + query.dtype + ) class BiasShape(Enum): @@ -290,6 +275,12 @@ class BiasShape(Enum): _11SS = "11SS" +class SeqDescFormat(Enum): + Mask = auto() + Seqlens = auto() + SegmentIDs = auto() + + @dataclass class FusedAttnRunner: """ @@ -309,7 +300,8 @@ class FusedAttnRunner: is_training: bool qkv_layout: QKVLayout bias_shape: BiasShape - window_size: Optional[Tuple[int, int]] = None + window_size: Tuple[int, int] + seq_desc_format: SeqDescFormat # Specifies sharding resources for distributed tests number_of_devices: int = 1 @@ -327,11 +319,14 @@ class FusedAttnRunner: # See https://docs.nvidia.com/deeplearning/cudnn/latest/release-notes.html#cudnn-9-4-0 for known issue # generating zero-length ragged tensors. This setting adjusts the test to avoid the zero-length cases. 
    def _get_max_segments_per_sequence(self):
-        if 90400 <= get_cudnn_version() < 90500:
-            return self.num_segments_per_seq
+        if self.qkv_layout.is_thd():
+            if 90400 <= get_cudnn_version() < 90500:
+                return self.num_segments_per_seq
+            else:
+                # +1 for testing runtime_segments < max_segments
+                return self.num_segments_per_seq + 1
         else:
-            # +1 for testing runtime_segments < max_segments
-            return self.num_segments_per_seq + 1
+            return 1

     def _check_configs(self):
         # TODO(rewang): probably adds this in is_fused_attn_available
@@ -462,11 +457,11 @@ def generate_random_segment_ids(
     ):
         rng = np.random.default_rng(seed=seed)
         # [1, 1, 1, 2, 2, 3, 3, 3, 3, 0, 0], 0 means pad
-        segment_ids = np.zeros((batch_size, sequence_length), dtype=int)
-        segment_pos = np.zeros((batch_size, sequence_length), dtype=int)
+        segment_ids = np.zeros((batch_size, sequence_length), dtype=np.int32)
+        segment_pos = np.zeros((batch_size, sequence_length), dtype=np.int32)
         # [0, 1, 2, 0, 1, 0, 1, 2, 3, 0, 0]
         # [0, 0, 1, 0, 1, 0, 0, 1, 1, 1, 1], 1 means pad
-        segment_pad = np.zeros((batch_size, sequence_length), dtype=int)
+        segment_pad = np.zeros((batch_size, sequence_length), dtype=np.int32)

         # Not include paddings
         max_segment_size = sequence_length // num_segments
@@ -541,16 +536,47 @@ def generate_random_segment_ids(
             self.window_size,
         )

+        # Test different input formats
         if self.qkv_layout.is_thd():
-            self.mask_for_customcall = None  # THD format doesn't support mask
+            match self.seq_desc_format:
+                case SeqDescFormat.Mask:
+                    pytest.skip("THD doesn't support mask input")
+                case SeqDescFormat.Seqlens:
+                    self.sequence_descriptor = SequenceDescriptor.from_seqlens_and_offsets(
+                        (self.seqlens_q, self.seqlens_kv),
+                        (self.offsets_q, self.offsets_kv),
+                    )
+                case SeqDescFormat.SegmentIDs:
+                    self.sequence_descriptor = SequenceDescriptor.from_segment_ids_and_pos(
+                        (self.segment_ids_q, self.segment_ids_kv),
+                        (self.segment_pos_q, self.segment_pos_kv),
+                    )
+                case _:
+                    raise ValueError(f"Unknown {self.seq_desc_format=}")
         else:
-            self.mask_for_customcall = make_mask(
-                self.segment_ids_q,
-                self.segment_ids_kv,
-                self.segment_pos_q,
-                self.segment_pos_kv,
-                self.attn_mask_type,
-            )
+            match self.seq_desc_format:
+                case SeqDescFormat.Mask:
+                    self.sequence_descriptor = make_mask(
+                        self.segment_ids_q,
+                        self.segment_ids_kv,
+                        self.segment_pos_q,
+                        self.segment_pos_kv,
+                        self.attn_mask_type,
+                    )
+                case SeqDescFormat.Seqlens:
+                    self.sequence_descriptor = SequenceDescriptor.from_seqlens(
+                        (
+                            self.segment_ids_q.sum(axis=-1).astype(jnp.int32),
+                            self.segment_ids_kv.sum(axis=-1).astype(jnp.int32),
+                        ),
+                    )
+                case SeqDescFormat.SegmentIDs:
+                    self.sequence_descriptor = SequenceDescriptor.from_segment_ids_and_pos(
+                        (self.segment_ids_q, self.segment_ids_kv),
+                        None,
+                    )
+                case _:
+                    raise ValueError(f"Unknown {self.seq_desc_format=}")

         self.dropout_rng = dropout_key if self.dropout_prob > 0 else None
         self.scaling_factor = 1.0 / sqrt(self.head_dim)
@@ -565,10 +591,21 @@ def generate_random_segment_ids(
         )
         self.qkvo_sharding = NamedSharding(self.mesh, self.qkvo_psec)

-        self.mask_pspec = PartitionSpec(
+        mask_pspec = PartitionSpec(
             self.mesh_resource.dp_resource, None, self.mesh_resource.cp_resource, None
         )
-        self.mask_sharding = NamedSharding(self.mesh, self.mask_pspec)
+        self.mask_sharding = NamedSharding(self.mesh, mask_pspec)
+
+        match self.seq_desc_format:
+            case SeqDescFormat.Mask:
+                self.seq_desc_sharding = self.mask_sharding
+            case _:
+
+                def to_dp_shardings(x):
+                    pspec = PartitionSpec(self.mesh_resource.dp_resource)
+                    return NamedSharding(self.mesh, pspec)
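+                # A SequenceDescriptor is a registered pytree, so the DP-only sharding
+                # below is mapped over each of its array leaves.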
+
+                self.seq_desc_sharding = jax.tree.map(to_dp_shardings, self.sequence_descriptor)

         if self.bias_shape == BiasShape._1HSS:
             self.bias_pspec = PartitionSpec(
@@ -631,11 +668,7 @@ def test_forward(self):
             jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
             jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
             jax.device_put(self.bias, self.bias_sharding),
-            jax.device_put(self.mask_for_customcall, self.mask_sharding),
-            jax.device_put(self.seqlens_q, self.seq_length_offset_sharding),
-            jax.device_put(self.seqlens_kv, self.seq_length_offset_sharding),
-            jax.device_put(self.offsets_q, self.seq_length_offset_sharding),
-            jax.device_put(self.offsets_kv, self.seq_length_offset_sharding),
+            jax.device_put(self.sequence_descriptor, self.seq_desc_sharding),
             jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
         ]
         kwargs = {
@@ -659,11 +692,7 @@ def test_forward(self):
                 self.qkvo_sharding,
                 self.qkvo_sharding,
                 self.bias_sharding,
-                self.mask_sharding,
-                self.seq_length_offset_sharding,
-                self.seq_length_offset_sharding,
-                self.seq_length_offset_sharding,
-                self.seq_length_offset_sharding,
+                self.seq_desc_sharding,
                 self.dropout_rng_sharding,
             ],
         )
@@ -722,11 +751,7 @@ def grad_func(func, *args, **kwargs):
             jax.device_put(self.cp_reorder_fn(self.k), self.qkvo_sharding),
             jax.device_put(self.cp_reorder_fn(self.v), self.qkvo_sharding),
             jax.device_put(self.bias, self.bias_sharding),
-            jax.device_put(self.mask_for_customcall, self.mask_sharding),
-            jax.device_put(self.seqlens_q, self.seq_length_offset_sharding),
-            jax.device_put(self.seqlens_kv, self.seq_length_offset_sharding),
-            jax.device_put(self.offsets_q, self.seq_length_offset_sharding),
-            jax.device_put(self.offsets_kv, self.seq_length_offset_sharding),
+            jax.device_put(self.sequence_descriptor, self.seq_desc_sharding),
             jax.device_put(self.dropout_rng, self.dropout_rng_sharding),
         ]
         kwargs = {
@@ -768,11 +793,7 @@ def grad_func(func, *args, **kwargs):
                 self.qkvo_sharding,
                 self.qkvo_sharding,
                 self.bias_sharding,
-                self.mask_sharding,
-                self.seq_length_offset_sharding,
-                self.seq_length_offset_sharding,
-                self.seq_length_offset_sharding,
-                self.seq_length_offset_sharding,
+                self.seq_desc_sharding,
                 self.dropout_rng_sharding,
             ),
             out_shardings=(None, grad_shardings),
@@ -883,10 +904,7 @@ def check_dqkv(primitive, reference, pad, idx):
 @pytest.mark.parametrize(
     "b, s_q, s_kv, h_q, h_kv, d, dtype",
     [
-        pytest.param(4, 128, 128, 16, 16, 64, jnp.bfloat16, id="4-128-128-16-16-64-BF16-SELF"),
-        pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
         pytest.param(2, 2048, 2048, 12, 12, 64, jnp.bfloat16, id="2-2048-2048-12-12-64-BF16-SELF"),
-        pytest.param(4, 128, 256, 16, 16, 64, jnp.bfloat16, id="4-128-256-16-16-64-BF16-CROSS"),
         pytest.param(
             2,
             2048,
             1024,
             12,
             12,
             64,
             jnp.bfloat16,
             id="2-2048-1024-12-12-64-BF16-CROSS",
         ),
-        pytest.param(4, 128, 128, 16, 8, 64, jnp.bfloat16, id="4-128-128-16-8-64-BF16-GQA"),
         pytest.param(2, 2048, 2048, 12, 6, 64, jnp.bfloat16, id="2-2048-2048-12-6-64-BF16-GQA"),
+        pytest.param(4, 128, 128, 16, 16, 64, jnp.float16, id="4-128-128-16-16-64-FP16-SELF"),
     ],
 )
 @pytest.mark.parametrize(
@@ -915,6 +933,14 @@ def check_dqkv(primitive, reference, pad, idx):
         pytest.param(True, id="SWA"),
     ],
 )
+@pytest.mark.parametrize(
+    "seq_desc_format",
+    [
+        pytest.param(SeqDescFormat.Mask, id="Mask"),
+        pytest.param(SeqDescFormat.Seqlens, id="Seqlens"),
+        pytest.param(SeqDescFormat.SegmentIDs, id="SegmentIDs"),
+    ],
+)
 class TestFusedAttn:
     """
     Fused
attention tester @@ -953,6 +979,7 @@ def _test_forward( qkv_layout, bias_shape, swa, + seq_desc_format, ): """ Test forward with parameterized configs @@ -977,6 +1004,7 @@ def _test_forward( qkv_layout, bias_shape, window_size, + seq_desc_format, ) runner.test_forward() @@ -1002,6 +1030,7 @@ def test_backward( qkv_layout, bias_shape, swa, + seq_desc_format, ): """ Test backward with parameterized configs @@ -1024,5 +1053,6 @@ def test_backward( qkv_layout, bias_shape, window_size, + seq_desc_format, ) runner.test_backward() diff --git a/transformer_engine/jax/attention.py b/transformer_engine/jax/attention.py index 7b6c605236..09128b013b 100644 --- a/transformer_engine/jax/attention.py +++ b/transformer_engine/jax/attention.py @@ -2,13 +2,16 @@ # # See LICENSE for license information. """JAX multi-head attention modules""" - +from __future__ import annotations from enum import Enum from functools import partial -from typing import Optional, Tuple +from typing import Optional, Tuple, Union +import warnings + from jax.ad_checkpoint import checkpoint_name import jax import jax.numpy as jnp +from flax.linen import make_attention_mask from transformer_engine.transformer_engine_jax import NVTE_Bias_Type from transformer_engine.transformer_engine_jax import NVTE_Mask_Type @@ -252,28 +255,24 @@ def make_helper(attn_mask_type): (-1, -1) if window_size is None else window_size, ) - if not make_helper(attn_mask_type).is_fused_attn_kernel_available(): - return False - - return True + return make_helper(attn_mask_type).is_fused_attn_kernel_available() def _obtain_batch_and_max_seqlen(qkv, qkv_layout): - match qkv_layout: - case QKVLayout.BS3HD | QKVLayout.T3HD: - assert len(qkv) == 1, f"qkv must be (qkvpacked,) with {qkv_layout=}" - batch, q_max_seqlen, *_ = qkv[0].shape - kv_max_seqlen = q_max_seqlen - case QKVLayout.BSHD_BS2HD | QKVLayout.THD_T2HD: - assert len(qkv) == 2, f"qkv must be (query, kvpacked) with {qkv_layout=}" - batch, q_max_seqlen, *_ = qkv[0].shape - kv_max_seqlen = qkv[1].shape[1] - case QKVLayout.BSHD_BSHD_BSHD | QKVLayout.THD_THD_THD: - assert len(qkv) == 3, f"qkv must be (query, key, value) with {qkv_layout=}" - batch, q_max_seqlen, *_ = qkv[0].shape - kv_max_seqlen = qkv[1].shape[1] - case _: - raise ValueError(f"Unsupported {qkv_layout=}") + if qkv_layout.is_qkvpacked(): + assert len(qkv) == 1, f"qkv must be (qkvpacked,) with {qkv_layout=}" + batch, q_max_seqlen, *_ = qkv[0].shape + kv_max_seqlen = q_max_seqlen + elif qkv_layout.is_kvpacked(): + assert len(qkv) == 2, f"qkv must be (query, kvpacked) with {qkv_layout=}" + batch, q_max_seqlen, *_ = qkv[0].shape + kv_max_seqlen = qkv[1].shape[1] + elif qkv_layout.is_separate(): + assert len(qkv) == 3, f"qkv must be (query, key, value) with {qkv_layout=}" + batch, q_max_seqlen, *_ = qkv[0].shape + kv_max_seqlen = qkv[1].shape[1] + else: + raise ValueError(f"Unsupported {qkv_layout=}") return batch, q_max_seqlen, kv_max_seqlen @@ -289,7 +288,273 @@ def inverse_reorder_causal_load_balancing(tensor, cp_size: int, tensor_format: Q return tex.attention.reorder_causal_load_balancing(tensor, cp_size, seq_dim, True) -def fused_attn( +def _get_seqlens_and_offsets(segment_ids, max_segments_per_seq): + # bincount map with 0s + bincount_vmap = jax.vmap(partial(jnp.bincount, length=max_segments_per_seq + 1)) + seqlens_with_zero = bincount_vmap(segment_ids.astype(jnp.int32)) + seqlens = seqlens_with_zero[..., 1:] + + def _find_offsets(x): + same_as_previous = jnp.logical_and(x[..., 1:] != x[..., :-1], x[..., 1:] != 0) + first_column = x[..., :1] != 
0
+        same_as_previous = jnp.hstack((first_column, same_as_previous))
+        return jax.vmap(partial(jnp.argwhere, size=(max_segments_per_seq + 1), fill_value=-1))(
+            same_as_previous
+        ).squeeze(-1)
+
+    offsets = _find_offsets(segment_ids)
+    return seqlens, offsets
+
+
+def _mask_to_seqlens_offset(mask, max_segments_per_seq):
+    assert mask.shape[1] == 1
+    row_ids = mask.squeeze(axis=1).max(axis=-1)
+    q_seqlen, q_offset = _get_seqlens_and_offsets(row_ids, max_segments_per_seq)
+    col_ids = mask.squeeze(axis=1).max(axis=-2)
+    kv_seqlen, kv_offset = _get_seqlens_and_offsets(col_ids, max_segments_per_seq)
+    return q_seqlen, q_offset, kv_seqlen, kv_offset
+
+
+def _segment_ids_pos_to_seqlens_offsets(
+    segment_ids_q,
+    segment_ids_kv,
+    segment_pos_q,
+    segment_pos_kv,
+    attn_mask_type,
+    window_size,
+    max_segments_per_seq,
+):
+    # (1 = attend, 0 = masked)
+    segment_mask = make_attention_mask(
+        segment_ids_q,
+        segment_ids_kv,
+        jnp.equal,
+    )
+    segment_mask_with_id = make_attention_mask(
+        segment_ids_q,
+        segment_ids_kv,
+        lambda x, y: jnp.equal(x, y) * x,
+    )
+    attn_mask = segment_mask
+    if attn_mask_type.is_causal():
+        causal_mask = make_attention_mask(
+            segment_pos_q,
+            segment_pos_kv,
+            jnp.greater_equal,
+        )
+        attn_mask = jnp.logical_and(segment_mask, causal_mask)
+
+    swa_mask = make_swa_mask(segment_pos_q, segment_pos_kv, window_size, dtype=jnp.bool)
+    attn_mask = jnp.logical_and(attn_mask, swa_mask)
+
+    attn_mask_with_id = jnp.where(attn_mask, segment_mask_with_id, 0)
+    q_seqlen, q_offset, kv_seqlen, kv_offset = _mask_to_seqlens_offset(
+        attn_mask_with_id, max_segments_per_seq
+    )
+    return q_seqlen, kv_seqlen, q_offset, kv_offset
+
+
+def _segment_ids_to_seqlens(segment_ids_q, segment_ids_kv, attn_mask_type):
+    # Convert the segment ids to seqlens; the mask path doesn't support ragged offsets
+    if not attn_mask_type.is_padding():
+        q_max_seqlen = segment_ids_q.shape[-1]
+        kv_max_seqlen = segment_ids_kv.shape[-1]
+        q_seq_lens = jnp.full_like(q_max_seqlen, q_max_seqlen, dtype=jnp.int32)
+        kv_seq_lens = jnp.full_like(kv_max_seqlen, kv_max_seqlen, dtype=jnp.int32)
+    else:
+        q_seq_lens = jnp.sum(segment_ids_q, axis=-1).astype(jnp.int32)
+        kv_seq_lens = jnp.sum(segment_ids_kv, axis=-1).astype(jnp.int32)
+    return q_seq_lens, kv_seq_lens
+
+
+@jax.tree_util.register_pytree_node_class
+class SequenceDescriptor:
+    """A class to describe the sequences with flexible initialization.
+    - SequenceDescriptor.from_seqlens
+        For non-THD (non-packed) cases, where each batch has only 1 sequence.
+    - SequenceDescriptor.from_seqlens_and_offsets
+        For THD (packed) cases, where each batch may contain more than one sequence.
+    - SequenceDescriptor.from_segment_ids_and_pos
+        Experimental feature for THD (packed) cases with context parallelism.
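+
+    A short usage sketch (the array names below are placeholders with the shapes
+    documented in each factory method):
+        # non-THD: one sequence per batch entry
+        desc = SequenceDescriptor.from_seqlens((q_seqlens, kv_seqlens))
+        # THD: packed segments described by ragged seqlens/offsets
+        desc = SequenceDescriptor.from_seqlens_and_offsets(
+            (q_seqlens, kv_seqlens), (q_offsets, kv_offsets))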
+ """ + + seqlens: Optional[Tuple[jnp.ndarray, jnp.ndarray]] + seq_offsets: Optional[Tuple[jnp.ndarray, jnp.ndarray]] + segment_ids: Optional[Tuple[jnp.ndarray, jnp.ndarray]] + segment_pos: Optional[Tuple[jnp.ndarray, jnp.ndarray]] + + def __init__(self, seqlens=None, seq_offsets=None, segment_ids=None, segment_pos=None): + """ + Initialize to Tuple(jnp.zeros, jnp.zeros) because the primitive only accepts pure jax array + """ + self.seqlens = (jnp.zeros(0), jnp.zeros(0)) if seqlens is None else seqlens + self.seq_offsets = (jnp.zeros(0), jnp.zeros(0)) if seq_offsets is None else seq_offsets + self.segment_ids = (jnp.zeros(0), jnp.zeros(0)) if segment_ids is None else segment_ids + self.segment_pos = (jnp.zeros(0), jnp.zeros(0)) if segment_pos is None else segment_pos + + def tree_flatten(self): + """ + Flatten method to register as a pytree node + """ + return ((self.seqlens, self.seq_offsets, self.segment_ids, self.segment_pos), None) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """ + Unflatten method to register as a pytree node + """ + del aux_data + return cls(*children) + + def get_seqlens_and_offsets( + self, attn_mask_type, qkv_layout, window_size, max_segments_per_seq + ): + """ + Acquire the seqlens/offsets for cuDNN backend + """ + attn_mask_type = AttnMaskType(attn_mask_type) + qkv_layout = QKVLayout(qkv_layout) + q_segment_ids, kv_segment_ids = self.segment_ids + q_segment_pos, kv_segment_pos = self.segment_pos + assert q_segment_ids.shape == q_segment_pos.shape + assert kv_segment_ids.shape == kv_segment_pos.shape + # No segment_ids/segment_pos + if q_segment_ids.size + kv_segment_ids.size == 0: + return self.seqlens, self.seq_offsets + + if qkv_layout.is_thd(): + q_seqlens, kv_seqlens, q_offsets, kv_offsets = _segment_ids_pos_to_seqlens_offsets( + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, + attn_mask_type, + window_size, + max_segments_per_seq, + ) + else: + q_seqlens, kv_seqlens = _segment_ids_to_seqlens( + q_segment_ids, + kv_segment_ids, + attn_mask_type, + ) + q_offsets = kv_offsets = jnp.zeros(0) + return (q_seqlens, kv_seqlens), (q_offsets, kv_offsets) + + @classmethod + def _expand_to_pair( + cls, value: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]] + ) -> Tuple[jnp.ndarray, jnp.ndarray]: + """ + Internal helper to ensure a single value expands into a pair (q, kv). + """ + if isinstance(value, tuple): + if len(value) != 2: + raise ValueError("Input tuple must have exactly 2 elements.") + return value + + if isinstance(value, jnp.ndarray): + return value, value # Duplicate for q=kv case + + raise TypeError( + "Expected a jax.numpy.ndarray or a tuple of two jax.numpy.ndarray, " + f"but got {type(value).__name__}." + ) + + @classmethod + def from_seqlens( + cls, + seqlens: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], + ) -> SequenceDescriptor: + """ + Factory method for inputs with sequence lengths only (non-THD). + Args: + seqlens(Tuple(jnp.ndarray, jnp.ndarray)) = (q_seqlens, kv_seqlens): + - q_seqlens (jnp.ndarray): + Sequence lengths for the query, with shape [batch]. + - kv_seqlen (jnp.ndarray): + Sequence lengths for the key and value, with shape [batch]. + Return: + A SequenceDescriptor with only seqlens initialized. 
+ """ + q_seqlens, kv_seqlens = cls._expand_to_pair(seqlens) + return cls(seqlens=(q_seqlens, kv_seqlens)) + + @classmethod + def from_seqlens_and_offsets( + cls, + seqlens: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], + seq_offsets: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], + ) -> SequenceDescriptor: + """ + Factory method for inputs with sequence lengths and offsets (THD). + Args: + seqlens(Tuple(jnp.ndarray, jnp.ndarray)) = (q_seqlens, kv_seqlens): + - q_seqlens (jnp.ndarray): + Sequence lengths for the query, with shape [batch, max_seqlen]. + Unused positions are padded with -1. + - kv_seqlen (jnp.ndarray): + Sequence lengths for the key and value, with shape [batch, max_seqlen]. + Unused positions are padded with -1. + seq_offsets(Tuple(jnp.ndarray, jnp.ndarray)) = (q_offsets, kv_offsets) + - q_seq_offsets (jnp.ndarray): + The offsets in the sequence dim for the query, with shape [batch, max_seqlen + 1]. + Unused positions are padded with -1. + - kv_seq_offsets (jnp.ndarray): + The offsets in the sequence dim for the query, with shape [batch, max_seqlen + 1]. + Unused positions are padded with -1. + Return: + A SequenceDescriptor with seqlens/seq_offsets initialized. + """ + q_seqlens, kv_seqlens = cls._expand_to_pair(seqlens) + q_offsets, kv_offsets = cls._expand_to_pair(seq_offsets) + return cls(seqlens=(q_seqlens, kv_seqlens), seq_offsets=(q_offsets, kv_offsets)) + + @classmethod + def from_segment_ids_and_pos( + cls, + segment_ids: Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]], + segment_pos: Optional[Union[jnp.ndarray, Tuple[jnp.ndarray, jnp.ndarray]]] = None, + ) -> SequenceDescriptor: + """ + Experimental factory method for inputs with segment IDs and optional positions. (THD) + Args: + segment_ids(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_ids, kv_segment_ids): + - q_segment_ids (jnp.ndarray): + Query segment ids start with 1, with shape [batch, max_seqlen]. + 0s are treated as paddings. + - kv_segment_ids (jnp.ndarray): + Key, value segment ids start with 1, with shape [batch, max_seqlen]. + 0s are treated as paddings. + segment_pos(Tuple(jnp.ndarray, jnp.ndarray)) = (q_segment_pos, kv_segment_pos) + - q_segment_pos (jnp.ndarray): + The position inside each segment for query, with shape [batch, max_seqlen]. + - kv_segment_pos (jnp.ndarray): + The position inside each segment for key, value, with shape [batch, max_seqlen]. + Return: + A SequenceDescriptor with segment_ids/segment_pos initialized. + """ + q_seg_ids, kv_seg_ids = cls._expand_to_pair(segment_ids) + + if segment_pos is not None: + segment_pos = cls._expand_to_pair(segment_pos) + else: + + def generate_default_pos(segment_ids): + seqlen = segment_ids.shape[-1] + return jnp.broadcast_to(jnp.arange(seqlen), segment_ids.shape) + + q_seg_pos = generate_default_pos(q_seg_ids) + kv_seg_pos = generate_default_pos(kv_seg_ids) + segment_pos = (q_seg_pos, kv_seg_pos) + + return cls( + segment_ids=(q_seg_ids, kv_seg_ids), + segment_pos=segment_pos, + ) + + +def _legacy_fused_attn( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], mask: Optional[jnp.ndarray], @@ -372,10 +637,7 @@ def fused_attn( output = _fused_attn( qkv, bias, - q_seq_lens, - kv_seq_lens, - None, - None, + SequenceDescriptor.from_seqlens((q_seq_lens, kv_seq_lens)), seed, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, @@ -414,63 +676,13 @@ def fused_attn_thd( context_parallel_axis: str = "", ): """ - (Experimental) Perform THD (packed) cuDNN fused attention. 
-
-    This function implements the following formula:
-    BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2
-    Args:
-        qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors.
-        It supports three formats:
-            - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key,
-              and value have the same shape (e.g., self-attention).
-            - `(query, kv_packed)`: For separate query and KV packed format, typically used when
-              query has a different shape (e.g., cross-attention).
-            - `(query, key, value)`: For separate query, key, and value tensors.
-        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
-        q_seqlen (jnp.ndarray):
-            Sequence lengths for the query, with shape [batch, max_seqlen]. Unused positions are
-            padded with -1.
-        kv_seqlen (jnp.ndarray):
-            Sequence lengths for the key and value, with shape [batch, max_seqlen]. Unused positions
-            are padded with -1.
-        q_seq_offsets (jnp.ndarray):
-            The offsets in the sequence dim for the query, with shape [batch, max_seqlen + 1].
-            Unused positions are padded with -1.
-        kv_seq_offsets (jnp.ndarray):
-            The offsets in the sequence dim for the query, with shape [batch, max_seqlen + 1].
-            Unused positions are padded with -1.
-        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
-        attn_bias_type (NVTE_Bias_Type): Type of attention bias.
-        attn_mask_type (NVTE_Mask_Type): Type of attention mask.
-        qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors.
-        scaling_factor (float): Scaling factor for the attention scores.
-        dropout_probability (float): Dropout probability to apply during attention.
-        is_training (bool): Flag indicating whether the model is in training mode.
-        max_segments_per_seq (int):
-            Indicating the maximum number of segments inside a sequence. This parameter is to
-            constrain the limit usage and need to be static during the e2e training. The XLA compile
-            time and memory consumption is proportional to `max_segments_per_seq`.
-        window_size (Optional[Tuple[int, int]]):
-            Sliding window size.
-        context_parallel_causal_load_balanced (bool):
-            Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
-        context_parallel_axis (str): The name of the context parallel axis.
-    Returns:
-        (jnp.ndarray): The output tensor from the fused attention.
-
-    Examples:
-        >>> # segment_ids = [[1, 1, 2, 3], [1, 1, 2, 0]], 0 means padded tokens
-        >>> b, s, h, d = 2, 4, 12, 64
-        >>> qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
-        >>> # 3 segments in first seq, 2 segments in second seq
-        >>> q_seq_lens = kv_seq_lens = jnp.asarray([[2, 1, 1, -1], [2, 1, -1, -1]])
-        >>> # seq_offsets need to include the end offset of the last segments
-        >>> q_seq_offsets = kv_seq_offsets = jnp.asarray([[0, 2, 3, 4, -1], [0, 2, 3, -1, -1]])
-        >>> out = fused_attn_thd((qkv,), None, q_seq_lens, kv_seq_lens,
-                                 q_seq_offsets, kv_seq_offsets, None,
-                                 AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK,
-                                 QKVLayout.T3HD, 0.125, 0, True, 3)
+    Deprecated THD (packed) fused attention. Please use fused_attn with a SequenceDescriptor instead.
     """
+    warnings.warn(
+        "fused_attn_thd is deprecated, please use fused_attn with SequenceDescriptor",
+        DeprecationWarning,
+    )
+
     assert (
         qkv_layout.is_thd()
     ), "Please use transformer_engine.jax.attention.fused_attn for non-THD format."
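A minimal migration sketch for existing fused_attn_thd callers, reusing the example values from the removed docstring above; the shapes, dtypes, and the seq_desc name are illustrative only:

    import jax.numpy as jnp
    from transformer_engine.jax.attention import (
        AttnBiasType,
        AttnMaskType,
        QKVLayout,
        SequenceDescriptor,
        fused_attn,
    )

    # segment_ids = [[1, 1, 2, 3], [1, 1, 2, 0]], 0 means padded tokens
    b, s, h, d = 2, 4, 12, 64
    qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
    # 3 segments in the first sequence, 2 in the second; unused slots are padded with -1
    q_seq_lens = kv_seq_lens = jnp.asarray([[2, 1, 1, -1], [2, 1, -1, -1]])
    # Offsets include the end offset of the last segment
    q_seq_offsets = kv_seq_offsets = jnp.asarray([[0, 2, 3, 4, -1], [0, 2, 3, -1, -1]])

    # Before: fused_attn_thd((qkv,), None, q_seq_lens, kv_seq_lens,
    #                        q_seq_offsets, kv_seq_offsets, None, ...)
    # After: wrap the same arrays in a SequenceDescriptor and call fused_attn.
    seq_desc = SequenceDescriptor.from_seqlens_and_offsets(
        seqlens=(q_seq_lens, kv_seq_lens),
        seq_offsets=(q_seq_offsets, kv_seq_offsets),
    )
    out = fused_attn(
        (qkv,),
        None,
        seq_desc,
        None,
        AttnBiasType.NO_BIAS,
        AttnMaskType.PADDING_CAUSAL_MASK,
        QKVLayout.T3HD,
        0.125,
        0.0,
        True,
        max_segments_per_seq=3,
    )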
@@ -497,10 +709,9 @@ def fused_attn_thd( output = _fused_attn( qkv, bias, - q_seq_lens, - kv_seq_lens, - q_seq_offsets, - kv_seq_offsets, + SequenceDescriptor.from_seqlens_and_offsets( + (q_seq_lens, kv_seq_lens), (q_seq_offsets, kv_seq_offsets) + ), seed, attn_bias_type=attn_bias_type, attn_mask_type=attn_mask_type, @@ -518,15 +729,12 @@ def fused_attn_thd( return output -@partial(jax.custom_vjp, nondiff_argnums=(7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17)) +@partial(jax.custom_vjp, nondiff_argnums=(4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14)) def _fused_attn( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], - q_seq_lens: jnp.ndarray, - kv_seq_lens: jnp.ndarray, - q_seq_offsets: Optional[jnp.ndarray], - kv_seq_offsets: Optional[jnp.ndarray], - seed: jnp.ndarray, + sequence_descriptor: SequenceDescriptor, + seed: Optional[jnp.ndarray], attn_bias_type: AttnBiasType, attn_mask_type: AttnMaskType, qkv_layout: QKVLayout, @@ -542,10 +750,7 @@ def _fused_attn( output, _ = _fused_attn_fwd_rule( qkv, bias, - q_seq_lens, - kv_seq_lens, - q_seq_offsets, - kv_seq_offsets, + sequence_descriptor, seed, attn_bias_type, attn_mask_type, @@ -565,10 +770,7 @@ def _fused_attn( def _fused_attn_fwd_rule( qkv, bias, - q_seq_lens, - kv_seq_lens, - q_seq_offsets, - kv_seq_offsets, + sequence_descriptor, seed, attn_bias_type, attn_mask_type, @@ -585,10 +787,7 @@ def _fused_attn_fwd_rule( output, softmax_aux, rng_state = tex.fused_attn_fwd( qkv, bias, - q_seq_lens, - kv_seq_lens, - q_seq_offsets, - kv_seq_offsets, + sequence_descriptor, seed, attn_bias_type=attn_bias_type.value, attn_mask_type=attn_mask_type.value, @@ -608,10 +807,7 @@ def _fused_attn_fwd_rule( return output, ( qkv, bias, - q_seq_lens, - kv_seq_lens, - q_seq_offsets, - kv_seq_offsets, + sequence_descriptor, softmax_aux, rng_state, output, @@ -636,10 +832,7 @@ def _fused_attn_bwd_rule( ( qkv, bias, - q_seq_lens, - kv_seq_lens, - q_seq_offsets, - kv_seq_offsets, + sequence_descriptor, softmax_aux, rng_state, output, @@ -651,10 +844,7 @@ def _fused_attn_bwd_rule( rng_state, output, dz, - q_seq_lens, - kv_seq_lens, - q_seq_offsets, - kv_seq_offsets, + sequence_descriptor, attn_bias_type=attn_bias_type.value, attn_mask_type=attn_mask_type.value, qkv_layout=qkv_layout.value, @@ -669,7 +859,137 @@ def _fused_attn_bwd_rule( ) if attn_bias_type == AttnBiasType.NO_BIAS: grad_bias = None - return grad_qkv, grad_bias, None, None, None, None, None + return ( + grad_qkv, + grad_bias, + None, + None, + ) _fused_attn.defvjp(_fused_attn_fwd_rule, _fused_attn_bwd_rule) + + +def fused_attn( + qkv: Tuple[jnp.ndarray, ...], + bias: Optional[jnp.ndarray], + sequence_descriptor: SequenceDescriptor, + seed: Optional[jnp.ndarray], + attn_bias_type: AttnBiasType, + attn_mask_type: AttnMaskType, + qkv_layout: QKVLayout, + scaling_factor: float, + dropout_probability: float, + is_training: bool, + max_segments_per_seq: int = 1, + window_size: Optional[Tuple[int, int]] = None, + context_parallel_strategy: CPStrategy = CPStrategy.DEFAULT, + context_parallel_causal_load_balanced: bool = False, + context_parallel_axis: str = "", +): + """ + Perform cuDNN fused attention. + + This function implements the following formula: + BMM1 -> (PreBias) -> ScaleMaskSoftmax -> (PostBias) -> (Dropout) -> BMM2 + Args: + qkv (Tuple[jnp.ndarray, ...]): A tuple containing query, key, and value tensors. + It supports three formats: + - `(qkv_packed,)`: For interleaved QKV packed format, typically used when query, key, + and value have the same shape (e.g., self-attention). 
+        - `(query, kv_packed)`: For separate query and KV packed format, typically used when
+          query has a different shape (e.g., cross-attention).
+        - `(query, key, value)`: For separate query, key, and value tensors.
+        bias (Optional[jnp.ndarray]): An optional bias tensor to be added to the attention scores.
+        sequence_descriptor (SequenceDescriptor): Describes the sequences in the batch, either as
+            sequence lengths and offsets or as segment IDs and positions.
+        seed (Optional[jnp.ndarray]): Optional random seed for dropout.
+        attn_bias_type (NVTE_Bias_Type): Type of attention bias.
+        attn_mask_type (NVTE_Mask_Type): Type of attention mask.
+        qkv_layout (NVTE_QKV_Layout): Layout of the QKV tensors.
+        scaling_factor (float): Scaling factor for the attention scores.
+        dropout_probability (float): Dropout probability to apply during attention.
+        is_training (bool): Flag indicating whether the model is in training mode.
+        max_segments_per_seq (int):
+            The maximum number of segments inside a sequence. This parameter constrains resource
+            usage and must be static during end-to-end training. XLA compile time and memory
+            consumption are proportional to `max_segments_per_seq`.
+        window_size (Optional[Tuple[int, int]]):
+            Sliding window size.
+        context_parallel_causal_load_balanced (bool):
+            Indicates the sequences are ordered for causal mask load balancing when running context parallelism.
+        context_parallel_axis (str): The name of the context parallel axis.
+    Returns:
+        (jnp.ndarray): The output tensor from the fused attention.
+
+    Examples (non-THD, also known as non-packed):
+        >>> # q_segment_ids = [[1, 1, 1, 0], [1, 1, 0, 0]], 0 means padded tokens
+        >>> # kv_segment_ids = [[1, 0, 0, 0], [1, 1, 0, 0]], 0 means padded tokens
+        >>> b, s, h, d = 2, 4, 12, 64
+        >>> qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
+        >>> q_seq_lens = jnp.asarray([3, 2])
+        >>> kv_seq_lens = jnp.asarray([1, 2])
+        >>> sequence_desc = SequenceDescriptor.from_seqlens(
+                seqlens=(q_seq_lens, kv_seq_lens))
+        >>> out = fused_attn((qkv,), None, sequence_desc, None,
+                             AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK,
+                             QKVLayout.BS3HD, 0.125, 0, True, 1)
+
+    Examples (THD, also known as packed):
+        >>> # segment_ids = [[1, 1, 2, 3], [1, 1, 2, 0]], 0 means padded tokens
+        >>> # segment_pos = [[0, 1, 0, 0], [0, 1, 0, 1]]
+        >>> b, s, h, d = 2, 4, 12, 64
+        >>> qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
+        >>> # 3 segments in first seq, 2 segments in second seq
+        >>> q_seq_lens = kv_seq_lens = jnp.asarray([[2, 1, 1, -1], [2, 1, -1, -1]])
+        >>> # seq_offsets must include the end offset of the last segment
+        >>> q_seq_offsets = kv_seq_offsets = jnp.asarray([[0, 2, 3, 4, -1], [0, 2, 3, -1, -1]])
+        >>> sequence_desc = SequenceDescriptor.from_seqlens_and_offsets(
+                seqlens=(q_seq_lens, kv_seq_lens),
+                seq_offsets=(q_seq_offsets, kv_seq_offsets))
+        >>> out = fused_attn((qkv,), None, sequence_desc, None,
+                             AttnBiasType.NO_BIAS, AttnMaskType.PADDING_CAUSAL_MASK,
+                             QKVLayout.T3HD, 0.125, 0, True, 3)
+    """
+    if isinstance(sequence_descriptor, jnp.ndarray):
+        warnings.warn(
+            "Passing a mask to fused_attn is deprecated, please use SequenceDescriptor instead. 
" + + "See help(transformer_engine.jax.attention.SequenceDescriptor) for details.", + DeprecationWarning, + ) + if max_segments_per_seq != 1: + raise ValueError("Passing mask is only supported for non-THD case.") + return _legacy_fused_attn( + qkv, + bias, + sequence_descriptor, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + window_size=window_size, + context_parallel_strategy=context_parallel_strategy, + context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, + context_parallel_axis=context_parallel_axis, + ) + output = _fused_attn( + qkv, + bias, + sequence_descriptor, + seed, + attn_bias_type=attn_bias_type, + attn_mask_type=attn_mask_type, + qkv_layout=qkv_layout, + scaling_factor=scaling_factor, + dropout_probability=dropout_probability, + is_training=is_training, + max_segments_per_seq=max_segments_per_seq, + window_size=window_size, + context_parallel_strategy=context_parallel_strategy, + context_parallel_causal_load_balanced=context_parallel_causal_load_balanced, + context_parallel_axis=context_parallel_axis, + ) + + return output diff --git a/transformer_engine/jax/cpp_extensions/attention.py b/transformer_engine/jax/cpp_extensions/attention.py index 3a116ffb63..ae3cfddccc 100644 --- a/transformer_engine/jax/cpp_extensions/attention.py +++ b/transformer_engine/jax/cpp_extensions/attention.py @@ -17,7 +17,7 @@ from jax.sharding import PartitionSpec, NamedSharding from jax.extend import ffi -from transformer_engine.jax.attention import CPStrategy +from transformer_engine.jax.attention import CPStrategy, SequenceDescriptor from transformer_engine import transformer_engine_jax from transformer_engine.transformer_engine_jax import ( @@ -211,9 +211,8 @@ def generate_cu_seqlen(actual_seqlen): """ Generating cumsum seqlen for a batch """ - cu_seqlen = jnp.cumsum(actual_seqlen, axis=-1) - cu_seqlen = jnp.where(actual_seqlen < 0, -1, cu_seqlen) - cu_seqlen = jnp.insert(cu_seqlen, 0, values=0, axis=-1) + actual_seqlen = jnp.where(actual_seqlen < 0, 0, actual_seqlen) + cu_seqlen = jnp.cumulative_sum(actual_seqlen, include_initial=True) return cu_seqlen @@ -224,7 +223,7 @@ class FusedAttnFwdPrimitive(BasePrimitive): name = "te_fused_attn_forward" multiple_results = True - impl_static_args = (9,) + impl_static_args = (13,) inner_primitive = None outer_primitive = None @@ -234,11 +233,15 @@ def abstract( k_aval, v_aval, bias_aval, + seed_aval, q_seqlen_or_cu_seqlen_aval, kv_seqlen_or_cu_seqlen_aval, _q_seq_offsets, _k_seq_offsets, - seed_aval, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, *, config: _FusedAttnConfig, ): @@ -354,11 +357,15 @@ def lowering( k, v, bias, + seed, q_cu_seqlen, kv_cu_seqlen, q_seq_offsets, k_seq_offsets, - seed, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, *, config: _FusedAttnConfig, ): @@ -387,11 +394,15 @@ def lowering( k, v, bias, + seed, q_cu_seqlen, kv_cu_seqlen, q_seq_offsets, k_seq_offsets, - seed, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, # ffi_lowering needs number of parameters meets primitive.lowering input_batch=input_batch, bias_batch=bias_batch, q_max_seqlen=q_max_seqlen, @@ -417,11 +428,11 @@ def lowering( k, v, bias, + seed, q_cu_seqlen, kv_cu_seqlen, q_seq_offsets, k_seq_offsets, - seed, ] operand_shapes = map(lambda x: x.type.shape, operands) out_types = [ @@ -466,15 +477,35 @@ def impl( 
k, v, bias, + seed, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets, - seed, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config: _FusedAttnConfig, ): assert FusedAttnFwdPrimitive.inner_primitive is not None + sequence_descriptor = SequenceDescriptor( + seqlens=(q_seqlen, kv_seqlen), + seq_offsets=(q_seq_offsets, k_seq_offsets), + segment_ids=(_q_segment_ids, _kv_segment_ids), + segment_pos=(_q_segment_pos, _kv_segment_pos), + ) + + (q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = ( + sequence_descriptor.get_seqlens_and_offsets( + config.attn_mask_type, + config.qkv_layout, + config.window_size, + config.max_segments_per_seq, + ) + ) + if nvte_get_qkv_format(config.qkv_layout) == NVTE_QKV_Format.NVTE_THD: def _fix_len_take(x, condition, fill_value=-1): @@ -517,6 +548,7 @@ def convert_to_2d(offsets, batch, max_seqlen): fill_value = 0 else: fill_value = -1 + q_seqlen = _fix_len_take(q_seqlen, q_seqlen > 0, fill_value=fill_value) kv_seqlen = _fix_len_take(kv_seqlen, kv_seqlen > 0, fill_value=fill_value) @@ -524,15 +556,17 @@ def convert_to_2d(offsets, batch, max_seqlen): # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]] q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen) k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen) + # Gather valid q_seq_offsets, which is greater and equal to 0 # [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]] - q_seq_offsets = _fix_len_take(q_seq_offsets, q_seq_offsets >= 0) - k_seq_offsets = _fix_len_take(k_seq_offsets, k_seq_offsets >= 0) - - # Set the unused position to max size (batch * max_seqlen) + # And set the unused position to max size (batch * max_seqlen) # [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]] - q_seq_offsets = jnp.where(q_seq_offsets < 0, q_batch * q_max_seqlen, q_seq_offsets) - k_seq_offsets = jnp.where(k_seq_offsets < 0, kv_batch * kv_max_seqlen, k_seq_offsets) + q_seq_offsets = _fix_len_take( + q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen + ) + k_seq_offsets = _fix_len_take( + k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen + ) q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten()) kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten()) @@ -542,11 +576,15 @@ def convert_to_2d(offsets, batch, max_seqlen): k, v, bias, + seed, q_cu_seqlen, kv_cu_seqlen, q_seq_offsets, k_seq_offsets, - seed, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=config, ) return output, softmax_aux, rng_state @@ -555,7 +593,7 @@ def convert_to_2d(offsets, batch, max_seqlen): def batcher(batched_args, batch_dims, *, config): check_valid_batch_dims(batch_dims) assert FusedAttnFwdPrimitive.outer_primitive is not None - q_bdim, *_, seed_bdim = batch_dims + q_bdim, _, _, _, seed_bdim, *_ = batch_dims out_bdims = q_bdim, q_bdim, seed_bdim return ( @@ -600,7 +638,9 @@ def partition(config, mesh, arg_infos, result_infos): rng_state_sharding = seed_sharding = NamedSharding( mesh, PartitionSpec(get_all_mesh_axes(), None) ) - arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) + arg_shardings = [arg_i.sharding for arg_i in arg_infos] + arg_shardings[4] = seed_sharding + arg_shardings = tuple(arg_shardings) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) impl = partial(FusedAttnFwdPrimitive.impl, config=config) return mesh, impl, out_shardings, arg_shardings @@ -616,7 +656,7 @@ class 
FusedAttnBwdPrimitive(BasePrimitive): name = "te_fused_attn_backward" multiple_results = True - impl_static_args = (12,) + impl_static_args = (16,) inner_primitive = None outer_primitive = None @@ -634,6 +674,10 @@ def abstract( kv_seqlen_or_cu_seqlen_aval, _q_seq_offsets, _k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, *, config, ): @@ -718,6 +762,10 @@ def lowering( kv_cu_seqlen, q_seq_offsets, k_seq_offsets, + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, *, config, ): @@ -754,6 +802,10 @@ def lowering( kv_cu_seqlen, q_seq_offsets, k_seq_offsets, + q_segment_ids, + kv_segment_ids, + q_segment_pos, + kv_segment_pos, # ffi_lowering needs number of parameters meets primitive.lowering input_batch=input_batch, bias_batch=bias_batch, q_max_seqlen=q_max_seqlen, @@ -839,10 +891,30 @@ def impl( kv_seqlen, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config, ): assert FusedAttnBwdPrimitive.inner_primitive is not None + sequence_descriptor = SequenceDescriptor( + seqlens=(q_seqlen, kv_seqlen), + seq_offsets=(q_seq_offsets, k_seq_offsets), + segment_ids=(_q_segment_ids, _kv_segment_ids), + segment_pos=(_q_segment_pos, _kv_segment_pos), + ) + + (q_seqlen, kv_seqlen), (q_seq_offsets, k_seq_offsets) = ( + sequence_descriptor.get_seqlens_and_offsets( + config.attn_mask_type, + config.qkv_layout, + config.window_size, + config.max_segments_per_seq, + ) + ) + if nvte_get_qkv_format(config.qkv_layout) == NVTE_QKV_Format.NVTE_THD: def _fix_len_take(x, condition, fill_value=-1): @@ -893,15 +965,17 @@ def convert_to_2d(offsets, batch, max_seqlen): # max_seqlen = 8, [[0, 3, 5, -1], [0, 2, 4, -1]] -> [[0, 3, 5, -1], [8, 11, 13, -1]] q_seq_offsets = convert_to_2d(q_seq_offsets, q_batch, q_max_seqlen) k_seq_offsets = convert_to_2d(k_seq_offsets, kv_batch, kv_max_seqlen) + # Gather valid q_seq_offsets, which is greater and equal to 0 # [[0, 3, 5, -1], [8, 11, 13, -1]] -> [[0, 3, 5, 8], [11, 13, -1, -1]] - q_seq_offsets = _fix_len_take(q_seq_offsets, q_seq_offsets >= 0) - k_seq_offsets = _fix_len_take(k_seq_offsets, k_seq_offsets >= 0) - - # Set the unused position to max size (batch * max_seqlen) + # And set the unused position to max size (batch * max_seqlen) # [[0, 3, 5, 8], [11, 13, -1, -1]] -> [[0, 3, 5, 8], [11, 13, b*s, b*s]] - q_seq_offsets = jnp.where(q_seq_offsets < 0, q_batch * q_max_seqlen, q_seq_offsets) - k_seq_offsets = jnp.where(k_seq_offsets < 0, kv_batch * kv_max_seqlen, k_seq_offsets) + q_seq_offsets = _fix_len_take( + q_seq_offsets, q_seq_offsets >= 0, fill_value=q_batch * q_max_seqlen + ) + k_seq_offsets = _fix_len_take( + k_seq_offsets, k_seq_offsets >= 0, fill_value=kv_batch * kv_max_seqlen + ) q_cu_seqlen = generate_cu_seqlen(q_seqlen.flatten()) kv_cu_seqlen = generate_cu_seqlen(kv_seqlen.flatten()) @@ -919,6 +993,10 @@ def convert_to_2d(offsets, batch, max_seqlen): kv_cu_seqlen, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=config, ) return dq, dk, dv, dbias @@ -975,6 +1053,10 @@ def sharded_impl( kv_cu_seqlen, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, ): local_dq, local_dk, local_dv, local_dbias = FusedAttnBwdPrimitive.impl( q, @@ -989,6 +1071,10 @@ def sharded_impl( kv_cu_seqlen, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=config, ) global_dbias = local_dbias @@ -1240,10 
+1326,26 @@ def partition(config, mesh, arg_infos, result_infos): rng_state_sharding = seed_sharding = NamedSharding( mesh, PartitionSpec(get_all_mesh_axes(), None) ) - arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) + arg_shardings = [arg_i.sharding for arg_i in arg_infos] + arg_shardings[4] = seed_sharding + arg_shardings = tuple(arg_shardings) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) - def impl(q, k, v, bias, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets, seed): + def impl( + q, + k, + v, + bias, + seed, + q_seqlen, + kv_seqlen, + q_seq_offsets, + k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + ): cp_size = get_mesh_axis_size(config.cp_axis, mesh) cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) @@ -1280,11 +1382,15 @@ def _cross_attn(idx, q, k, v, bias, q_seqlen, kv_seqlen, seed): k_unmasked, v_unmasked, bias, + seed, q_seqlen_for_step, kv_seqlen_for_step, q_seq_offsets, k_seq_offsets, - seed, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=helper.get_step_config(), ) results.append((output, softmax_aux, rng_state)) @@ -1357,13 +1463,31 @@ def impl( kv_seqlen, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, ): cp_size = get_mesh_axis_size(config.cp_axis, mesh) cp_rank = get_mesh_axis_rank(config.cp_axis, mesh) # See comment in FusedAttnCPFwdPrimitive.partition for why we define this function. def _cross_attn_bwd( - idx, q, k, v, bias, softmax_aux, rng_state, output, doutput, q_seqlen, kv_seqlen + idx, + q, + k, + v, + bias, + softmax_aux, + rng_state, + output, + doutput, + q_seqlen, + kv_seqlen, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, ): kv_max_seqlen = k.shape[1] kv_seqlen_per_subrank = kv_max_seqlen // (cp_size * 2) @@ -1402,6 +1526,10 @@ def _cross_attn_bwd( kv_seqlen_for_step, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=helper.get_step_config(), ) @@ -1433,6 +1561,10 @@ def _cross_attn_bwd( doutput, q_seqlen, kv_seqlen, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, ) for idx in range(cp_size) ] @@ -1595,7 +1727,9 @@ def partition(config, mesh, arg_infos, result_infos): rng_state_sharding = seed_sharding = NamedSharding( mesh, PartitionSpec(get_all_mesh_axes(), None) ) - arg_shardings = tuple([arg_i.sharding for arg_i in arg_infos[:-1]] + [seed_sharding]) + arg_shardings = [arg_i.sharding for arg_i in arg_infos] + arg_shardings[4] = seed_sharding + arg_shardings = tuple(arg_shardings) out_shardings = (out_sharding, softmax_aux_sharding, rng_state_sharding) def ring_attn_fwd_impl( @@ -1603,11 +1737,15 @@ def ring_attn_fwd_impl( k, v, bias, + seed, q_seqlen, kv_seqlen, q_seq_offsets, k_seq_offsets, - seed, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, ): _not_used = jnp.zeros(0, dtype=v.dtype) @@ -1644,12 +1782,16 @@ def mask_compute(attn_mask_type): kv, _not_used, bias, + seed, q_seqlen_per_step, kv_seqlen_per_step, q_seq_offsets, k_seq_offsets, - seed, - helper.get_step_config(attn_mask_type), + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, + config=helper.get_step_config(attn_mask_type), ) return output_per_step, softmax_aux_per_step @@ -1665,11 +1807,15 @@ def half_kv_no_mask_compute(): kv_part, _not_used, bias, + seed, q_seqlen_per_step, kv_seqlen_per_step, q_seq_offsets, k_seq_offsets, - seed, 
+ _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), ) return output_per_step, softmax_aux_per_step @@ -1683,11 +1829,15 @@ def half_q_no_mask_compute(): kv, _not_used, bias, + seed, q_seqlen_per_step, kv_seqlen_per_step, q_seq_offsets, k_seq_offsets, - seed, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), ) output_per_step = jnp.concat([jnp.zeros_like(q_part), output_per_step], axis=1) @@ -1805,6 +1955,10 @@ def ring_attn_bwd_impl( kv_seqlen, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, ): _not_used = jnp.zeros(0, dtype=output.dtype) @@ -1849,6 +2003,10 @@ def mask_compute(attn_mask_type): kv_seqlen_per_step, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=helper.get_step_config(attn_mask_type), ) return dq_per_step, dk_dv_per_step, dbias_per_step @@ -1873,6 +2031,10 @@ def half_kv_no_mask_compute(): kv_seqlen_per_step, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), ) dk_dv_per_step = jnp.concat( @@ -1907,6 +2069,10 @@ def half_q_no_mask_compute(): kv_seqlen_per_step, q_seq_offsets, k_seq_offsets, + _q_segment_ids, + _kv_segment_ids, + _q_segment_pos, + _kv_segment_pos, config=helper.get_step_config(NVTE_Mask_Type.NVTE_NO_MASK), ) dq_per_step = jnp.concat([jnp.zeros_like(dq_per_step), dq_per_step], axis=1) @@ -1975,10 +2141,7 @@ def _maybe_context_parallel_axis(cp_axis: str): def fused_attn_fwd( qkv: Tuple[jnp.ndarray, ...], bias: Optional[jnp.ndarray], - q_seqlen: jnp.ndarray, - kv_seqlen: jnp.ndarray, - q_seq_offsets: Optional[jnp.ndarray], - kv_seq_offsets: Optional[jnp.ndarray], + sequence_descriptor: SequenceDescriptor, seed: Optional[jnp.ndarray], attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, @@ -2031,14 +2194,9 @@ def fused_attn_fwd( (jnp.ndarray): The output tensor from the fused attention. """ seed = _FusedAttnRNGStateChecker().check_seed(seed, dropout_probability, is_training) - - assert (q_seq_offsets is None) == ( - kv_seq_offsets is None - ), "Both q_seq_offsets and kv_seq_offsets must be either None or have values." 
- is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD - # For optional tensors, which custom calls doesn't support None _not_used = jnp.zeros(0, dtype=qkv[0].dtype) + match qkv_layout: case NVTE_QKV_Layout.NVTE_BS3HD | NVTE_QKV_Layout.NVTE_T3HD: assert len(qkv) == 1, f"qkv=(packed_qkv,) is expected with {qkv_layout=} but got {qkv=}" @@ -2071,21 +2229,19 @@ def fused_attn_fwd( cp_axis=_maybe_context_parallel_axis(context_parallel_axis), ) - primative = None + primitive = None match context_parallel_strategy: case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER: - primative = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive + primitive = FusedAttnCPWithAllGatherFwdPrimitive.outer_primitive case CPStrategy.RING: - primative = FusedRingAttnFwdPrimitive.outer_primitive + primitive = FusedRingAttnFwdPrimitive.outer_primitive - return primative.bind( + seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor) + return primitive.bind( *qkv_for_primitive, bias, - q_seqlen, - kv_seqlen, - q_seq_offsets if is_ragged else _not_used, - kv_seq_offsets if is_ragged else _not_used, seed, + *seq_desc_flatten, config=fused_config, ) @@ -2097,10 +2253,7 @@ def fused_attn_bwd( rng_state: jnp.ndarray, output: jnp.ndarray, doutput: jnp.ndarray, - q_seqlen: jnp.ndarray, - kv_seqlen: jnp.ndarray, - q_seq_offsets: Optional[jnp.ndarray], - kv_seq_offsets: Optional[jnp.ndarray], + sequence_descriptor: SequenceDescriptor, attn_bias_type: NVTE_Bias_Type, attn_mask_type: NVTE_Mask_Type, qkv_layout: NVTE_QKV_Layout, @@ -2155,12 +2308,6 @@ def fused_attn_bwd( same format as the input `qkv`. - The second value is the gradient with respect to `bias`, or `None` if `bias` is `None`. """ - - assert (q_seq_offsets is None) == ( - kv_seq_offsets is None - ), "Both q_seq_offsets and kv_seq_offsets must be either None or have values." 
- is_ragged = nvte_get_qkv_format(qkv_layout) == NVTE_QKV_Format.NVTE_THD - # For optional tensors, which custom calls doesn't support None _not_used = jnp.zeros(0, dtype=qkv[0].dtype) @@ -2196,24 +2343,23 @@ def fused_attn_bwd( cp_axis=_maybe_context_parallel_axis(context_parallel_axis), ) - primative = None + primitive = None match context_parallel_strategy: case CPStrategy.DEFAULT | CPStrategy.ALL_GATHER: - primative = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive + primitive = FusedAttnCPWithAllGatherBwdPrimitive.outer_primitive case CPStrategy.RING: - primative = FusedRingAttnBwdPrimitive.outer_primitive + primitive = FusedRingAttnBwdPrimitive.outer_primitive + + seq_desc_flatten, _ = jax.tree.flatten(sequence_descriptor) - *qkv_grads, bias_grad = primative.bind( + *qkv_grads, bias_grad = primitive.bind( *qkv_for_primitive, bias, softmax_aux, rng_state, output, doutput, - q_seqlen, - kv_seqlen, - q_seq_offsets if is_ragged else _not_used, - kv_seq_offsets if is_ragged else _not_used, + *seq_desc_flatten, config=fused_config, ) return tuple(qkv_grads[: len(qkv)]), bias_grad diff --git a/transformer_engine/jax/csrc/extensions/attention.cpp b/transformer_engine/jax/csrc/extensions/attention.cpp index dc857aa22c..7447cd1871 100644 --- a/transformer_engine/jax/csrc/extensions/attention.cpp +++ b/transformer_engine/jax/csrc/extensions/attention.cpp @@ -213,14 +213,14 @@ pybind11::tuple GetFusedAttnForwardWorkspaceSizes( auto layout_group = nvte_get_qkv_layout_group(qkv_layout); static void FusedAttnForwardImpl( - cudaStream_t stream, void *q, void *k, void *v, void *bias, void *q_cu_seqlens, - void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *seed, void *output, - void *softmax_aux, void *rng_state, void *workspace, size_t input_batch, size_t bias_batch, - size_t q_max_seqlen, size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, - size_t bias_heads, size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size, - float scaling_factor, float dropout_probability, NVTE_Bias_Type bias_type, - NVTE_Mask_Type mask_type, NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, - bool is_training, bool deterministic, int64_t window_size_left, int64_t window_size_right) { + cudaStream_t stream, void *q, void *k, void *v, void *bias, void *seed, void *q_cu_seqlens, + void *kv_cu_seqlens, void *q_seq_offsets, void *k_seq_offsets, void *output, void *softmax_aux, + void *rng_state, void *workspace, size_t input_batch, size_t bias_batch, size_t q_max_seqlen, + size_t kv_max_seqlen, size_t attn_heads, size_t num_gqa_groups, size_t bias_heads, + size_t head_dim, size_t max_segments_per_seq, size_t wkspace_size, float scaling_factor, + float dropout_probability, NVTE_Bias_Type bias_type, NVTE_Mask_Type mask_type, + NVTE_QKV_Layout qkv_layout, DType dtype, DType wkspace_dtype, bool is_training, + bool deterministic, int64_t window_size_left, int64_t window_size_right) { FUSED_ATTN_IMPL_COMMON_BLOCK; /* Input tensors */ @@ -303,11 +303,11 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s void *k = buffers[1]; void *v = buffers[2]; void *bias = buffers[3]; - void *q_cu_seqlens = buffers[4]; - void *kv_cu_seqlens = buffers[5]; - void *q_seq_offsets = is_ragged ? buffers[6] : nullptr; - void *k_seq_offsets = is_ragged ? buffers[7] : nullptr; - void *seed = buffers[8]; + void *seed = buffers[4]; + void *q_cu_seqlens = buffers[5]; + void *kv_cu_seqlens = buffers[6]; + void *q_seq_offsets = is_ragged ? 
buffers[7] : nullptr; + void *k_seq_offsets = is_ragged ? buffers[8] : nullptr; /* Output buffer from XLA */ void *output = buffers[9]; @@ -316,7 +316,7 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s void *workspace = buffers[12]; FusedAttnForwardImpl( - stream, q, k, v, bias, q_cu_seqlens, kv_cu_seqlens, q_seq_offsets, k_seq_offsets, seed, + stream, q, k, v, bias, seed, q_cu_seqlens, kv_cu_seqlens, q_seq_offsets, k_seq_offsets, output, softmax_aux, rng_state, workspace, descriptor.input_batch, descriptor.bias_batch, descriptor.q_max_seqlen, descriptor.kv_max_seqlen, descriptor.attn_heads, descriptor.num_gqa_groups, descriptor.bias_heads, descriptor.head_dim, @@ -354,24 +354,24 @@ void FusedAttnForward(cudaStream_t stream, void **buffers, const char *opaque, s DType wkspace_dtype = convert_ffi_datatype_to_te_dtype(workspace_buf->element_type()); Error_Type FusedAttnForwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_Type k_buf, - Buffer_Type v_buf, Buffer_Type bias_buf, + Buffer_Type v_buf, Buffer_Type bias_buf, Buffer_Type seed_buf, Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf, Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf, - Buffer_Type seed_buf, Result_Type output_buf, + Variadic_Buffer_Type _unused_args, Result_Type output_buf, Result_Type softmax_aux_buf, Result_Type rng_state_buf, Result_Type workspace_buf, Dictionary attrs) { FUSED_ATTN_FFI_GET_ATTRS; FusedAttnForwardImpl( stream, q_buf.untyped_data(), k_buf.untyped_data(), v_buf.untyped_data(), - bias_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(), kv_cu_seqlens_buf.untyped_data(), - is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr, - is_ragged ? k_seq_offsets_buf.untyped_data() : nullptr, seed_buf.untyped_data(), - output_buf->untyped_data(), softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(), - workspace_buf->untyped_data(), input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, - attn_heads, num_gqa_groups, bias_heads, head_dim, max_segments_per_seq, wkspace_size, - scaling_factor, dropout_probability, bias_type, mask_type, qkv_layout, dtype, wkspace_dtype, - is_training, deterministic, window_size_left, window_size_right); + bias_buf.untyped_data(), seed_buf.untyped_data(), q_cu_seqlens_buf.untyped_data(), + kv_cu_seqlens_buf.untyped_data(), is_ragged ? q_seq_offsets_buf.untyped_data() : nullptr, + is_ragged ? 
k_seq_offsets_buf.untyped_data() : nullptr, output_buf->untyped_data(), + softmax_aux_buf->untyped_data(), rng_state_buf->untyped_data(), workspace_buf->untyped_data(), + input_batch, bias_batch, q_max_seqlen, kv_max_seqlen, attn_heads, num_gqa_groups, bias_heads, + head_dim, max_segments_per_seq, wkspace_size, scaling_factor, dropout_probability, bias_type, + mask_type, qkv_layout, dtype, wkspace_dtype, is_training, deterministic, window_size_left, + window_size_right); return ffi_with_cuda_error_check(); } @@ -383,11 +383,12 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnForwardHandler, FusedAttnForwardFFI, .Arg() // k .Arg() // v .Arg() // bias + .Arg() // seed_buf .Arg() // q_cu_seqlens .Arg() // kv_cu_seqlens .Arg() // q_seq_offsets .Arg() // k_seq_offsets - .Arg() // seed_buf + .RemainingArgs() // _cp_aux_args unused .Ret() // output .Ret() // softmax_aux .Ret() // rng_state @@ -642,9 +643,9 @@ Error_Type FusedAttnBackwardFFI(cudaStream_t stream, Buffer_Type q_buf, Buffer_T Buffer_Type output_buf, Buffer_Type doutput_buf, Buffer_Type q_cu_seqlens_buf, Buffer_Type kv_cu_seqlens_buf, Buffer_Type q_seq_offsets_buf, Buffer_Type k_seq_offsets_buf, - Result_Type dq_buf, Result_Type dk_buf, Result_Type dv_buf, - Result_Type dbias_buf, Result_Type workspace_buf, - Dictionary attrs) { + Variadic_Buffer_Type _unused_args, Result_Type dq_buf, + Result_Type dk_buf, Result_Type dv_buf, Result_Type dbias_buf, + Result_Type workspace_buf, Dictionary attrs) { FUSED_ATTN_FFI_GET_ATTRS; FusedAttnBackwardImpl( @@ -677,6 +678,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(FusedAttnBackwardHandler, FusedAttnBackwardFFI, .Arg() // kv_cu_seqlens .Arg() // q_seq_offsets .Arg() // k_seq_offsets + .RemainingArgs() // _cp_aux_args unused .Ret() // dq .Ret() // dk .Ret() // dv
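For completeness, a minimal sketch of the experimental segment-ID entry point added in this change, following the conventions documented in SequenceDescriptor.from_segment_ids_and_pos (segment ids start at 1, 0 marks padding); all shapes and values are illustrative only:

    import jax.numpy as jnp
    from transformer_engine.jax.attention import (
        AttnBiasType,
        AttnMaskType,
        QKVLayout,
        SequenceDescriptor,
        fused_attn,
    )

    b, s, h, d = 2, 4, 12, 64
    qkv = jnp.zeros((b, s, 3, h, d), dtype=jnp.bfloat16)
    # Two packed batches: segment ids label each token's segment, 0 marks padding
    segment_ids = jnp.asarray([[1, 1, 2, 3], [1, 1, 2, 0]])
    # Position of each token within its own segment
    segment_pos = jnp.asarray([[0, 1, 0, 0], [0, 1, 0, 1]])

    seq_desc = SequenceDescriptor.from_segment_ids_and_pos(
        segment_ids=(segment_ids, segment_ids),
        segment_pos=(segment_pos, segment_pos),
    )
    out = fused_attn(
        (qkv,),
        None,
        seq_desc,
        None,
        AttnBiasType.NO_BIAS,
        AttnMaskType.PADDING_CAUSAL_MASK,
        QKVLayout.T3HD,
        0.125,
        0.0,
        True,
        max_segments_per_seq=3,
    )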