Commit 4455c0a
Merge branch 'main' into backward_compatible_activation_recompute
ksivaman authored Mar 28, 2024
2 parents c3b28bb + c1a68f6 commit 4455c0a
Showing 11 changed files with 1,028 additions and 1,973 deletions.
40 changes: 20 additions & 20 deletions tests/jax/test_distributed_fused_attn.py
@@ -16,7 +16,7 @@
from utils import make_causal_mask, make_self_mask
from transformer_engine.jax import fp8_autocast
from transformer_engine.jax.fused_attn import is_fused_attn_kernel_available
-from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn
+from transformer_engine.jax.fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked
from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout

DTYPES = [jnp.float16, jnp.bfloat16]
@@ -86,15 +86,15 @@ def test_self_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, dat

def target_func(qkv, bias, mask):
return jnp.mean(
-self_fused_attn(qkv,
-bias,
-mask,
-None,
-attn_bias_type=attn_bias_type,
-attn_mask_type=attn_mask_type,
-scaling_factor=scaling_factor,
-dropout_probability=dropout_prob,
-is_training=is_training))
+fused_attn_qkvpacked(qkv,
+bias,
+mask,
+None,
+attn_bias_type=attn_bias_type,
+attn_mask_type=attn_mask_type,
+scaling_factor=scaling_factor,
+dropout_probability=dropout_prob,
+is_training=is_training))

def ref_func(qkv, bias, mask):
query, key, value = jnp.split(qkv, [1, 2], axis=-3)
@@ -192,16 +192,16 @@ def test_cross_attn(self, device_count, mesh_shape, mesh_axes, mesh_resource, da

def target_func(q, kv, mask):
return jnp.mean(
-cross_fused_attn(q,
-kv,
-None,
-mask,
-None,
-attn_bias_type=attn_bias_type,
-attn_mask_type=attn_mask_type,
-scaling_factor=scaling_factor,
-dropout_probability=dropout_prob,
-is_training=is_training))
+fused_attn_kvpacked(q,
+kv,
+None,
+mask,
+None,
+attn_bias_type=attn_bias_type,
+attn_mask_type=attn_mask_type,
+scaling_factor=scaling_factor,
+dropout_probability=dropout_prob,
+is_training=is_training))

def ref_func(query, kv, mask):
key, value = jnp.split(kv, [1], axis=-3)
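Note: the hunks above replace the deprecated self_fused_attn and cross_fused_attn entry points with fused_attn_qkvpacked and fused_attn_kvpacked. Below is a minimal, self-contained sketch of calling the renamed packed-QKV function with the same argument order and keyword names that appear in the diff; the tensor shapes, mask convention, and scaling factor are illustrative assumptions, not values taken from the repository.

    # Minimal sketch of the renamed packed-QKV call. Argument order (qkv, bias,
    # mask, dropout rng) and keyword names mirror the diff above; the shapes,
    # mask convention, and scaling factor are illustrative assumptions.
    import jax
    import jax.numpy as jnp
    from transformer_engine.jax.fused_attn import (AttnBiasType, AttnMaskType,
                                                   fused_attn_qkvpacked)

    b, s, h, d = 2, 128, 8, 64
    qkv = jax.random.normal(jax.random.PRNGKey(0), (b, s, 3, h, d), jnp.bfloat16)
    mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)  # nothing masked out

    out = fused_attn_qkvpacked(qkv,
                               None,                 # bias (NO_BIAS below)
                               mask,
                               None,                 # dropout rng, unused at p=0.0
                               attn_bias_type=AttnBiasType.NO_BIAS,
                               attn_mask_type=AttnMaskType.NO_MASK,
                               scaling_factor=1.0 / d ** 0.5,
                               dropout_probability=0.0,
                               is_training=True)

fused_attn_kvpacked follows the same pattern for the cross-attention case, taking the query and the packed KV tensor as its first two arguments, as in the second hunk above.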
102 changes: 51 additions & 51 deletions tests/jax/test_fused_attn.py
@@ -2,8 +2,6 @@
#
# See LICENSE for license information.
"""Tests for fused attention"""
-import sys
-
from enum import Enum
from dataclasses import dataclass
from functools import partial
@@ -21,7 +19,7 @@
from jax.typing import ArrayLike, DTypeLike

from transformer_engine.jax.fused_attn import AttnBiasType, AttnMaskType, QKVLayout
-from transformer_engine.jax.fused_attn import self_fused_attn, cross_fused_attn, fused_attn
+from transformer_engine.jax.fused_attn import fused_attn_qkvpacked, fused_attn_kvpacked, fused_attn
from transformer_engine.jax.cpp_extensions import FusedAttnHelper

from transformer_engine_jax import NVTE_Fused_Attn_Backend
@@ -144,18 +142,22 @@ def customcall_fused_dpa(query, key, value, bias, q_token, kv_token, dropout_rng
case QKVLayout.BS3HD:
query, key, value = map(partial(jnp.expand_dims, axis=-3), [query, key, value])
qkv = jnp.concatenate((query, key, value), axis=-3)
-return self_fused_attn(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
+return fused_attn_qkvpacked(qkv, bias, mask, dropout_rng, **kwargs).astype(query.dtype)
case QKVLayout.BSHD_BS2HD:
key, value = map(partial(jnp.expand_dims, axis=-3), [key, value])
kv = jnp.concatenate((key, value), axis=-3)
-return cross_fused_attn(query, kv, bias, mask, dropout_rng,
-**kwargs).astype(query.dtype)
+return fused_attn_kvpacked(query, kv, bias, mask, dropout_rng,
+**kwargs).astype(query.dtype)
case QKVLayout.BSHD_BSHD_BSHD:
return fused_attn(query, key, value, bias, mask, dropout_rng,
**kwargs).astype(query.dtype)


class BiasShape(Enum):
"""
Enum class to represent the different bias shapes used in the fused attention.
"""

BIAS_1HSS = '1HSS'
BIAS_B1SS = 'B1SS'
BIAS_BHSS = 'BHSS'
@@ -188,17 +190,16 @@ def _check_configs(self):
if self.qkv_layout == QKVLayout.BS3HD and self.max_seqlen_q != self.max_seqlen_kv:
pytest.skip("BS3HD layout requires max_seqlen_q and max_seqlen_kv to be equal.")

-self.backend = FusedAttnHelper(
-self.dtype, self.dtype, self.qkv_layout.value, self.attn_bias_type.value,
-self.attn_mask_type.value, self.dropout_prob, self.num_heads_q, self.num_heads_kv,
-self.max_seqlen_q, self.max_seqlen_kv, self.head_dim).get_fused_attn_backend()
+self.backend = FusedAttnHelper(self.dtype, self.dtype, self.qkv_layout.value,
+self.attn_bias_type.value, self.attn_mask_type.value,
+self.dropout_prob, self.num_heads_q, self.num_heads_kv,
+self.max_seqlen_q, self.max_seqlen_kv,
+self.head_dim).get_fused_attn_backend()
if self.backend == NVTE_Fused_Attn_Backend.NVTE_No_Backend:
pytest.skip("Unsupported inputs combination or device compute capability.")

if self.bias_shape != BiasShape.BIAS_1HSS:
-if self.attn_bias_type != AttnBiasType.POST_SCALE_BIAS:
-pytest.skip("B1SS, BHSS and 11SS bias shapes require POST_SCALE_BIAS.")
-elif self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
+if self.attn_bias_type == AttnBiasType.POST_SCALE_BIAS:
+if self.attn_mask_type not in [AttnMaskType.NO_MASK, AttnMaskType.CAUSAL_MASK]:
pytest.skip("B1SS, BHSS and 11SS bias shapes are only supported for "
"AttnMaskType.NO_MASK and AttnMaskType.CAUSAL_MASK.")
elif self.backend != NVTE_Fused_Attn_Backend.NVTE_F16_arbitrary_seqlen:
@@ -213,7 +214,9 @@ def _setup_inputs(self):
q_shape = (self.batch_size, self.max_seqlen_q, self.num_heads_q, self.head_dim)
k_shape = v_shape = (self.batch_size, self.max_seqlen_kv, self.num_heads_kv, self.head_dim)

-if self.bias_shape == BiasShape.BIAS_1HSS:
+if self.attn_bias_type == AttnBiasType.NO_BIAS:
+bias_shape = None
+elif self.bias_shape == BiasShape.BIAS_1HSS:
bias_shape = (1, self.num_heads_q, self.max_seqlen_q, self.max_seqlen_kv)
elif self.bias_shape == BiasShape.BIAS_B1SS:
bias_shape = (self.batch_size, 1, self.max_seqlen_q, self.max_seqlen_kv)
@@ -222,7 +225,7 @@
elif self.bias_shape == BiasShape.BIAS_11SS:
bias_shape = (1, 1, self.max_seqlen_q, self.max_seqlen_kv)
else:
-pytest.xfail("PyTest attempted to use an unrecognized bias layout!")
+pytest.fail(f"PyTest attempted to use an unrecognized bias_layout = {self.bias_shape}!")

self.q = jax.random.uniform(q_key, q_shape, self.dtype, -1.)
self.k = jax.random.uniform(k_key, k_shape, self.dtype, -1.)
@@ -237,7 +240,7 @@
cudnn_neg_inf = -2.**27. if self.dtype == jnp.bfloat16 else -2.**15.
self.bias = jnp.full(bias_shape, cudnn_neg_inf, dtype=self.dtype)
max_id = min(self.max_seqlen_q, self.max_seqlen_kv)
-seq_id_size = max_id * 5 // 128 # 5 ids per interval of 128 sequences
+seq_id_size = max_id * 5 // 128  # 5 ids per interval of 128 sequences
seq_id = jax.random.randint(bias_key, (int(seq_id_size),), 0, max_id).tolist()
for i in range(1, len(seq_id)):
self.bias = \
@@ -327,8 +330,8 @@ def grad_func(func, *args, **kwargs):
**kwargs), arg_nums))
jitted_reference = jit(
value_and_grad(
-lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args,
-**kwargs), arg_nums))
+lambda q, k, v, bias, *args: grad_func(jax_dpa, q, k, v, bias, *args, **kwargs),
+arg_nums))

primitive_out, primitive_dgrad = jitted_primitive(*args)
reference_out, reference_dgrad = jitted_reference(*args)
@@ -361,10 +364,10 @@ def check_dqkv(primitive, reference, valid_len):
primitive_dbias = jnp.float32(primitive_dgrad[3])
reference_dbias = jnp.float32(reference_dgrad[3])

-assert_allclose(
-primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
-jnp.zeros_like(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:]),
-dtype=self.dtype)
+assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
+jnp.zeros_like(primitive_dbias[..., self.valid_len_q:,
+self.valid_len_kv:]),
+dtype=self.dtype)

# dbias padded part
assert_allclose(primitive_dbias[..., self.valid_len_q:, self.valid_len_kv:],
@@ -376,15 +379,13 @@ def check_dqkv(primitive, reference, valid_len):
reference_dbias[..., :self.valid_len_q, :self.valid_len_kv],
dtype=self.dtype)

-@pytest.mark.parametrize('bias_shape', [
-pytest.param(BiasShape.BIAS_1HSS, id='1-H-S-S'),
-pytest.param(BiasShape.BIAS_B1SS, id='B-1-S-S'),
-pytest.param(BiasShape.BIAS_BHSS, id='B-H-S-S'),
-pytest.param(BiasShape.BIAS_11SS, id='1-1-S-S'),
-])
-@pytest.mark.parametrize('attn_bias_type', [
-pytest.param(AttnBiasType.NO_BIAS, id='NO_BIAS'),
-pytest.param(AttnBiasType.POST_SCALE_BIAS, id='POST_SCALE_BIAS'),
+
+@pytest.mark.parametrize('attn_bias_type, bias_shape', [
+pytest.param(AttnBiasType.NO_BIAS, None, id='NO_BIAS'),
+pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_1HSS, id='POST_SCALE_BIAS-1HSS'),
+pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_B1SS, id='POST_SCALE_BIAS-B1SS'),
+pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_BHSS, id='POST_SCALE_BIAS-BHSS'),
+pytest.param(AttnBiasType.POST_SCALE_BIAS, BiasShape.BIAS_11SS, id='POST_SCALE_BIAS-11SS'),
])
@pytest.mark.parametrize('attn_mask_type', [
pytest.param(AttnMaskType.NO_MASK, id='NO_MASK'),
@@ -399,31 +400,32 @@ def check_dqkv(primitive, reference, valid_len):
])
@pytest.mark.parametrize('dtype', [
pytest.param(jnp.bfloat16, id="BF16"),
-pytest.param(jnp.float16, id="FP16")
+pytest.param(jnp.float16, id="FP16"),
])
-@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d',[
-pytest.param(32, 128, 128, 16, 16, 64, id='32-128-128-16-16-64-SELF'),
-pytest.param( 4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'),
-pytest.param(32, 512, 128, 16, 16, 64, id='32-512-128-16-16-64-CROSS'),
-pytest.param( 4, 2048, 1024, 12, 12, 64, id='4-2048-1048-12-12-64-CROSS'),
-pytest.param(32, 128, 128, 16, 8, 64, id='32-128-128-16-8-64-GQA'),
-pytest.param( 4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA')
+@pytest.mark.parametrize('b, s_q, s_kv, h_q, h_kv, d', [
+pytest.param(32, 128, 128, 16, 16, 64, id='32-128-128-16-16-64-SELF'),
+pytest.param(4, 2048, 2048, 12, 12, 64, id='4-2048-2048-12-12-64-SELF'),
+pytest.param(32, 512, 128, 16, 16, 64, id='32-512-128-16-16-64-CROSS'),
+pytest.param(4, 2048, 1024, 12, 12, 64, id='4-2048-1048-12-12-64-CROSS'),
+pytest.param(32, 128, 128, 16, 8, 64, id='32-128-128-16-8-64-GQA'),
+pytest.param(4, 2048, 2048, 12, 6, 64, id='4-2048-2048-12-6-64-GQA'),
])
@pytest.mark.parametrize('dropout_prob', [
pytest.param(0.0, id="DROP_0.0"),
-pytest.param(0.1, id="DROP_0.1")
-])
-@pytest.mark.parametrize('is_training', [
-pytest.param(True, id='TRAINING'),
-pytest.param(False, id='INFERENCE'),
+pytest.param(0.1, id="DROP_0.1"),
])
class TestFusedAttn:
"""
Fused attention tester
"""

@staticmethod
-def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
-dropout_prob, dtype, is_training, qkv_layout, bias_shape):
+@pytest.mark.parametrize('is_training', [
+pytest.param(True, id='TRAINING'),
+pytest.param(False, id='INFERENCE'),
+])
+def test_forward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
+dtype, is_training, qkv_layout, bias_shape):
"""
Test forward with parameterized configs
"""
@@ -432,13 +434,11 @@
runner.test_forward()

@staticmethod
-def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
-dropout_prob, dtype, is_training, qkv_layout, bias_shape):
+def test_backward(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type, dropout_prob,
+dtype, qkv_layout, bias_shape):
"""
Test backward with parameterized configs
"""
-if not is_training:
-pytest.skip("Backward pass does not support inference.")
runner = FusedAttnRunner(b, s_q, s_kv, h_q, h_kv, d, attn_bias_type, attn_mask_type,
dropout_prob, dtype, True, qkv_layout, bias_shape)
runner.test_backward()
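The bias-related parametrization above is collapsed into a single @pytest.mark.parametrize over the pair (attn_bias_type, bias_shape), so NO_BIAS is never crossed with a concrete bias shape and the corresponding skip logic in _check_configs can be simplified. A standalone sketch of the same pairing pattern, using hypothetical names rather than the repository's classes:

    # Standalone sketch of pairing two parameters in one parametrize call so
    # that invalid combinations never get generated. Names are hypothetical.
    from enum import Enum

    import pytest


    class BiasType(Enum):
        NO_BIAS = 'no_bias'
        POST_SCALE_BIAS = 'post_scale_bias'


    @pytest.mark.parametrize('bias_type, bias_shape', [
        pytest.param(BiasType.NO_BIAS, None, id='NO_BIAS'),
        pytest.param(BiasType.POST_SCALE_BIAS, '1HSS', id='POST_SCALE_BIAS-1HSS'),
        pytest.param(BiasType.POST_SCALE_BIAS, 'B1SS', id='POST_SCALE_BIAS-B1SS'),
    ])
    def test_bias_pairing(bias_type, bias_shape):
        # Because the parameters are paired, no skip for NO_BIAS + shape is needed.
        if bias_type is BiasType.NO_BIAS:
            assert bias_shape is None
        else:
            assert bias_shape is not None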
2 changes: 2 additions & 0 deletions tests/jax/test_layer.py
@@ -449,6 +449,7 @@ def forward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
hidden_dropout_dims=(sequence_dim,),
intermediate_dropout_dims=(sequence_dim,),
layer_type=TransformerLayerType.DECODER,
+self_attn_mask_type='padding_causal',
dtype=dtype,
**te_layer_attrs)
ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
@@ -497,6 +498,7 @@ def forward_backward_runner(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-
hidden_dropout_dims=(sequence_dim,),
intermediate_dropout_dims=(sequence_dim,),
layer_type=TransformerLayerType.DECODER,
+self_attn_mask_type='padding_causal',
dtype=dtype,
**te_layer_attrs)
ref_layer, ref_params, ref_others = generate_layer(ref_layer_cls, init_rng, inputs,
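Both decoder test layers above now pass self_attn_mask_type='padding_causal'. Conceptually, a padding-causal mask is the intersection of a lower-triangular causal mask with a padding mask over the valid tokens; the sketch below illustrates that combination in plain JAX and is not the TransformerLayer implementation.

    # Illustrative sketch of a combined padding + causal mask in plain JAX.
    # 'seqlen' and 'valid_len' are hypothetical sizes, not values from the tests.
    import jax.numpy as jnp

    seqlen, valid_len = 8, 5
    causal = jnp.tril(jnp.ones((seqlen, seqlen), dtype=jnp.bool_))
    padding = jnp.arange(seqlen) < valid_len          # True for real tokens
    padding_causal = causal & padding[None, :] & padding[:, None]
    # padding_causal[i, j] is True only when j <= i and both positions are valid.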
33 changes: 24 additions & 9 deletions tests/jax/test_praxis_layers.py
@@ -730,8 +730,13 @@ class TestDotProductAttn(TestLayer):
def input_getter(self, shape, dtype):
key = jax.random.PRNGKey(seed=1234)
q_key, k_key, v_key = jax.random.split(key, 3)
-return list(map(partial(jax.random.normal, shape=shape, dtype=dtype),
-[q_key, k_key, v_key]))
+b, s, *_ = shape
+if self.attrs[DotProductAttnAttr.TRANSPOSE_BS]:
+b, s = s, b
+mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
+return [
+*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, k_key, v_key]), mask
+]

def get_layer_name(self):
return 'dot_product_attn'
@@ -765,6 +770,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs):
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', DotProductAttnAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
+self.attrs = attrs
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

@@ -853,9 +859,11 @@ class MultiHeadAttnAttr:
class TestMultiHeadAttn(TestLayer):

def input_getter(self, shape, dtype):
-data_key = jax.random.PRNGKey(seed=1234)
-return (jax.random.normal(data_key, shape,
-dtype), jax.random.normal(data_key, shape, dtype))
+key = jax.random.PRNGKey(seed=1234)
+q_key, kv_key = jax.random.split(key, 2)
+s, b, *_ = shape
+mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
+return [*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask]

def get_layer_name(self):
return 'multi_head_attn'
@@ -1183,9 +1191,15 @@ class TransformerLayerAttr:
class TestTransformer(TestLayer):

def input_getter(self, shape, dtype):
-data_key = jax.random.PRNGKey(seed=1234)
-return (jax.random.normal(data_key, shape,
-dtype), jax.random.normal(data_key, shape, dtype))
+key = jax.random.PRNGKey(seed=1234)
+q_key, kv_key = jax.random.split(key, 2)
+b, s, *_ = shape
+if self.attrs[TransformerLayerAttr.TRANSPOSE_BS]:
+b, s = s, b
+mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
+return [
+*map(partial(jax.random.normal, shape=shape, dtype=dtype), [q_key, kv_key]), mask, mask
+]

def get_layer_name(self):
return 'transformerlayer'
@@ -1277,6 +1291,7 @@ def generate_praxis_p_and_flax_cls(self, dtype, attrs):
@pytest.mark.parametrize('dtype', DTYPE)
@pytest.mark.parametrize('attrs', TransformerLayerAttr.ATTRS)
def test_forward_backward(self, data_shape, dtype, attrs, rtol=1e-05, atol=1e-08):
+self.attrs = attrs
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
self.forward_backward_runner(data_shape, dtype, praxis_p, flax_cls, rtol, atol)

@@ -1292,7 +1307,7 @@ def test_forward_backward_fp8(self,
fp8_format,
rtol=1e-05,
atol=1e-08):

+self.attrs = attrs
ds = DelayedScaling(fp8_format=fp8_format)
with fp8_autocast(enabled=True, fp8_recipe=ds):
praxis_p, flax_cls = self.generate_praxis_p_and_flax_cls(dtype, attrs)
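The updated input_getters above now return an explicit all-zeros attention mask of shape (batch, 1, seqlen, seqlen) alongside the random activations, swapping batch and sequence when the layer is configured with transposed batch/sequence inputs. A minimal sketch of that construction outside the test classes; the helper name and shapes are illustrative, and 0 in the mask marks positions that are attended:

    # Minimal sketch of building the all-zeros ("attend to everything") mask
    # returned by the updated input_getters. The helper name and shapes are
    # illustrative assumptions.
    from functools import partial

    import jax
    import jax.numpy as jnp

    def make_attn_inputs(shape, dtype, transpose_batch_sequence=False):
        key = jax.random.PRNGKey(1234)
        q_key, k_key, v_key = jax.random.split(key, 3)
        b, s = shape[:2]
        if transpose_batch_sequence:      # layer expects [seqlen, batch, ...]
            b, s = s, b
        mask = jnp.zeros((b, 1, s, s), dtype=jnp.uint8)
        q, k, v = map(partial(jax.random.normal, shape=shape, dtype=dtype),
                      (q_key, k_key, v_key))
        return q, k, v, mask

    q, k, v, mask = make_attn_inputs((4, 16, 8, 64), jnp.bfloat16)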