[JAX] Fused attention unit test fixes and refinements (#1352)
* Add util functions to attn_mask_type

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add util functions to qkv_layout

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Fix THD cross reference code

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove explicit segment_pad, encoding it into segment_ids

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add jax.jit, replace _token with segment_ids, rename bias shape enum

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add comment for make_mask

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Clean code

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Add docstrings for the added functions

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Remove cache for fa deterministic, which caused UT failures

Signed-off-by: Reese Wang <rewang@nvidia.com>

* Rename fixture to avoid conflict

Signed-off-by: Reese Wang <rewang@nvidia.com>

---------

Signed-off-by: Reese Wang <rewang@nvidia.com>
zlsh80826 authored Dec 17, 2024
1 parent f4f35c2 commit 7f5c784
Showing 6 changed files with 201 additions and 152 deletions.
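
The diffs below replace module-level helpers with methods on the enums themselves: `qkv_layout.get_qkv_format()` and `attn_mask_type.is_causal()`. A minimal sketch of what such utility methods might look like, with member names partly taken from the diff and partly assumed (the real definitions in `transformer_engine.jax` may differ):

```python
from enum import Enum


class QKVFormat(Enum):
    BSHD = 0
    THD = 1


class QKVLayout(Enum):
    BS3HD = 0
    BSHD_BS2HD = 1
    BSHD_BSHD_BSHD = 2
    T3HD = 3
    THD_T2HD = 4
    THD_THD_THD = 5

    def get_qkv_format(self):
        # Packed THD layouts map to the THD format; the rest are BSHD.
        return QKVFormat.THD if self.name.startswith("T") else QKVFormat.BSHD


class AttnMaskType(Enum):
    NO_MASK = 0
    PADDING_MASK = 1
    CAUSAL_MASK = 2
    PADDING_CAUSAL_MASK = 3
    CAUSAL_BOTTOM_RIGHT_MASK = 4
    PADDING_CAUSAL_BOTTOM_RIGHT_MASK = 5

    def is_causal(self):
        # Any mask variant that applies a causal constraint.
        return "CAUSAL" in self.name
```

With helpers like these, a test can write `attn_mask_type.is_causal()` instead of enumerating every causal mask variant, which is the simplification visible in the grad_func hunk below.
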
2 changes: 1 addition & 1 deletion tests/jax/conftest.py
@@ -20,7 +20,7 @@ def clear_live_arrays():


@pytest.fixture(autouse=True, scope="module")
-def enable_fused_attn():
+def enable_fused_attn_after_hopper():
"""
Enable fused attn for hopper+ arch.
Fused attn kernels on pre-hopper arch are not deterministic.
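
The conftest change above only renames the autouse fixture. For context, a hedged sketch of how a module-scoped fixture like this could enable fused attention only on Hopper or newer GPUs; `NVTE_FUSED_ATTN` is Transformer Engine's enable flag, while the capability check and teardown here are assumptions rather than this commit's actual code:

```python
import os

import jax
import pytest


@pytest.fixture(autouse=True, scope="module")
def enable_fused_attn_after_hopper():
    """Enable fused attn only on Hopper+; pre-Hopper kernels are not deterministic."""
    # Compute capability 9.0 corresponds to Hopper (sm90).
    if float(jax.devices()[0].compute_capability) >= 9.0:
        os.environ["NVTE_FUSED_ATTN"] = "1"
    yield
    os.environ.pop("NVTE_FUSED_ATTN", None)
```
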
6 changes: 2 additions & 4 deletions tests/jax/test_distributed_fused_attn.py
@@ -20,7 +20,6 @@
from utils import (
make_causal_mask,
make_self_mask,
-assert_tree_like_allclose,
assert_allclose,
print_debug_tensor_stats,
)
@@ -32,7 +31,6 @@
AttnMaskType,
QKVLayout,
QKVFormat,
-get_qkv_format,
reorder_causal_load_balancing,
inverse_reorder_causal_load_balancing,
CPStrategy,
@@ -421,7 +419,7 @@ def impl_test_contex_parallel_attn(
dropout_prob = 0.0
is_training = True
dp_size, cp_size, tp_size = mesh_shape
-qkv_format = get_qkv_format(qkv_layout)
+qkv_format = qkv_layout.get_qkv_format()

batch, seqlen, num_head, hidden = data_shape

@@ -503,7 +501,7 @@ def grad_func(func, *args, **kwargs):
# Gradient is small, use a gradient multiplier to amplify the gradient
_, max_seq_len, num_heads, _ = data_shape
gradient_multiplier = max_seq_len * num_heads
-if attn_mask_type in [AttnMaskType.CAUSAL_MASK, AttnMaskType.CAUSAL_BOTTOM_RIGHT_MASK]:
+if attn_mask_type.is_causal():
gradient_multiplier /= 10
ret_valid = func(*args, **kwargs)
return (jnp.mean(ret_valid, dtype=jnp.float32) * gradient_multiplier).astype(dtype)
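
The commit message also notes removing the explicit `segment_pad` input and encoding it into `segment_ids`. One plausible encoding, shown purely as an illustration (the sentinel value and helper names below are hypothetical, not taken from this commit), is to reserve a single id for padded tokens so that one integer array carries both segment membership and padding:

```python
import jax.numpy as jnp

PAD_ID = 0  # assumed sentinel for padded tokens; real segments use ids >= 1


def fold_pad_into_segment_ids(segment_ids, segment_pad):
    # Overwrite padded positions with the sentinel so a separate
    # segment_pad array is no longer needed.
    return jnp.where(segment_pad == 1, PAD_ID, segment_ids)


def make_attention_mask(q_segment_ids, kv_segment_ids):
    # Tokens attend to each other only if they share a (non-pad) segment id.
    same_segment = q_segment_ids[:, :, None] == kv_segment_ids[:, None, :]
    not_pad = (q_segment_ids[:, :, None] != PAD_ID) & (kv_segment_ids[:, None, :] != PAD_ID)
    return same_segment & not_pad
```

A mask built this way blocks attention to and from padded positions automatically, so the reference code no longer has to thread a separate `segment_pad` array through every call.
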