From dac0001911139b70f51f3db14ef2c1d96d6161d2 Mon Sep 17 00:00:00 2001 From: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> Date: Tue, 23 Apr 2024 17:08:24 -0700 Subject: [PATCH] [JAX] Unifying GeLU and GeGLU in LayerNorm MLP (#765) * combined layernorm_geglu with layernorm_gelu into fused_layernorm Signed-off-by: Phuong Nguyen * fixes to pass all unit tests in test_custom_call_compute.py, test_layer.py, and test_praxis_layer.py Signed-off-by: Phuong Nguyen * cleaning and formatting Signed-off-by: Phuong Nguyen * renaming based on reviewers suggestions Signed-off-by: Phuong Nguyen * implemented partial fused layernorm Signed-off-by: Phuong Nguyen * geglu + bias passed tests Signed-off-by: Phuong Nguyen * added partial fused calculation for dbias_1 Signed-off-by: Phuong Nguyen * clean up Co-authored-by: Alp Dener Signed-off-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> --------- Signed-off-by: Phuong Nguyen Signed-off-by: Phuong Nguyen <36155692+phu0ngng@users.noreply.github.com> Co-authored-by: Alp Dener --- tests/jax/test_custom_call_compute.py | 189 ++---- .../common/transpose/cast_transpose_fusion.cu | 9 +- transformer_engine/jax/cpp_extensions.py | 225 +++++++ transformer_engine/jax/csrc/extensions.cpp | 2 + transformer_engine/jax/csrc/modules.cpp | 63 ++ transformer_engine/jax/csrc/modules.h | 6 + transformer_engine/jax/flax/module.py | 105 ++-- transformer_engine/jax/mlp.py | 593 ++++++------------ 8 files changed, 575 insertions(+), 617 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 8aa6c399f4..139ef994fa 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -4,6 +4,7 @@ import functools import operator +from typing import Callable, Sequence, Union import jax import jax.numpy as jnp @@ -22,8 +23,7 @@ from transformer_engine.jax.fp8 import FP8MetaPackage, FP8Helper from transformer_engine.jax.fp8 import is_fp8_available from transformer_engine.jax.layernorm import layernorm -from transformer_engine.jax.mlp import layernorm_geglu_fp8_mlp -from transformer_engine.jax.mlp import layernorm_gelu_fp8_mlp +from transformer_engine.jax.mlp import fused_layernorm_fp8_mlp GEMM_CASES = [ (256, 256, 512), @@ -174,17 +174,32 @@ def ref_func(x, y): assert_allclose(primitive_b_grad, ref_b_grad, dtype=FP8Helper.BWD_DTYPE) @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024), + @pytest.mark.parametrize('m,n,k', [(256, 512, 128), (16384, 1024, 2816), (16384, 2816, 1024), (16384, 1024, 1024)]) - def test_grad_ln_geglu_fp8_mlp(self, m, n, k): + @pytest.mark.parametrize('activation_type', [('gelu', ), + ('gelu', 'linear')]) + @pytest.mark.parametrize('use_bias', [True, False]) + def test_grad_fused_layernorm_fp8_mlp(self, m, n, k, + activation_type: Sequence[Union[str, Callable]], + use_bias: bool): + """ N/a """ key = jax.random.PRNGKey(0) - subkeys = jax.random.split(key, 4) - activations = ('gelu', 'linear') + subkeys = jax.random.split(key, 6) + + activation_dict = { + ('gelu', ): jax.nn.gelu + } a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16) - k1 = jax.random.normal(subkeys[1], (k, len(activations), n), jnp.bfloat16) + k1 = jax.random.normal(subkeys[1], (k, len(activation_type), n), jnp.bfloat16) k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) - s = jax.random.normal(subkeys[3], (k,), jnp.bfloat16) + s = jax.random.normal(subkeys[5], (k,), 
jnp.bfloat16) + if use_bias: + b1 = jax.random.normal(subkeys[3], (len(activation_type), n), jnp.bfloat16) + b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16) + else: + b1 = jax.random.normal(subkeys[3], (0,), jnp.bfloat16) + b2 = jax.random.normal(subkeys[4], (0,), jnp.bfloat16) init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2) init_fp8_metas_amax = jnp.zeros( @@ -192,14 +207,16 @@ def test_grad_ln_geglu_fp8_mlp(self, m, n, k): init_fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32) init_fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32) - def primitive_func(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale, + def primitive_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv): # x is input tensor, matrix 2d # y, z are weights, matrix 2d - # out = (x * y) * z + # out = ((x * y) + w) * z + v fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv) - return jnp.mean(layernorm_geglu_fp8_mlp(x, ln_s, None, [y, z], fp8_meta_pkg, "rmsnorm")) + return jnp.mean( + fused_layernorm_fp8_mlp(x, ln_s, None, [y, z], [w, v], fp8_meta_pkg, "rmsnorm", + activation_type = activation_type, use_bias = use_bias)) def _convert_to_activation_function(fn_or_string): """Convert a string to an activation function.""" @@ -211,115 +228,7 @@ def _convert_to_activation_function(fn_or_string): return fn_or_string raise ValueError(f"don't know how to convert {fn_or_string} to an activation function") - def ln_geglu_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray, - kernel_2: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray, - scale: jnp.ndarray, scale_inv: jnp.ndarray) -> jnp.ndarray: - - x = jnp.asarray(x, jnp.float32) - mean2 = jnp.mean(jax.lax.square(x), axis=-1, keepdims=True) - y = jnp.asarray(x * jax.lax.rsqrt(mean2 + 1e-6), jnp.bfloat16) - ln_out = y * ln_scale - ln_out = jnp.asarray(ln_out, jnp.bfloat16) - - fp8_gemm_1_pkg = FP8MetaPackage(1, fp8_maxs[:FP8Helper.NUM_META_PER_GEMM], - amax[:FP8Helper.NUM_META_PER_GEMM], - scale[:FP8Helper.NUM_META_PER_GEMM], - scale_inv[:FP8Helper.NUM_META_PER_GEMM]) - linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_gemm_1_pkg, ((1,), (0,))) - - x = jnp.split(linear_1_out, len(activations), axis=-2) - acts = [] - for idx, act_fn in enumerate(activations): - x_i = _convert_to_activation_function(act_fn)(x[idx]) - acts.append(x_i) - x = functools.reduce(operator.mul, acts) - x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16) - - fp8_gemm_2_pkg = FP8MetaPackage(1, fp8_maxs[FP8Helper.NUM_META_PER_GEMM:], - amax[FP8Helper.NUM_META_PER_GEMM:], - scale[FP8Helper.NUM_META_PER_GEMM:], - scale_inv[FP8Helper.NUM_META_PER_GEMM:]) - output = type_safe_dot_general(x, kernel_2, fp8_gemm_2_pkg, ((1,), (0,))) - return output - - def ref_func(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv): - return jnp.mean( - ln_geglu_fp8_mlp_ref(x, ln_s, y, z, fp8_max, fp8_metas_amax, fp8_metas_scale, - fp8_metas_scale_inv)) - - value_n_grad_primitive_func = jit(value_and_grad(primitive_func, (0, 1, 2, 3, 4, 5, 6, 7))) - value_n_grad_ref_func = jit(value_and_grad(ref_func, (0, 1, 2, 3, 4, 5, 6, 7))) - - ref_fp8_max = init_fp8_max - ref_fp8_metas_amax = init_fp8_metas_amax - ref_fp8_metas_scale = init_fp8_metas_scale - ref_fp8_metas_scale_inv = init_fp8_metas_scale_inv - - pri_fp8_max = init_fp8_max - pri_fp8_metas_amax = init_fp8_metas_amax - pri_fp8_metas_scale = 
init_fp8_metas_scale - pri_fp8_metas_scale_inv = init_fp8_metas_scale_inv - - for _ in range(3): - ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_fp8_max, - ref_fp8_metas_amax, ref_fp8_metas_scale, - ref_fp8_metas_scale_inv) = value_n_grad_ref_func( - a, s, k1, k2, ref_fp8_max, ref_fp8_metas_amax, ref_fp8_metas_scale, - ref_fp8_metas_scale_inv) - - for _ in range(3): - primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad, - primitive_k2_grad, pri_fp8_max, pri_fp8_metas_amax, pri_fp8_metas_scale, - pri_fp8_metas_scale_inv) = value_n_grad_primitive_func( - a, s, k1, k2, pri_fp8_max, pri_fp8_metas_amax, pri_fp8_metas_scale, - pri_fp8_metas_scale_inv) - - assert_allclose(primitive_out, ref_out, dtype=FP8Helper.FWD_DTYPE) - assert_allclose(jnp.asarray(primitive_a_grad, np.float32), - jnp.asarray(ref_a_grad, np.float32), - dtype=FP8Helper.BWD_DTYPE) - assert_allclose(jnp.asarray(primitive_k1_grad, np.float32), - jnp.asarray(ref_k1_grad, np.float32), - dtype=FP8Helper.BWD_DTYPE) - assert_allclose(jnp.asarray(primitive_k2_grad, np.float32), - jnp.asarray(ref_k2_grad, np.float32), - dtype=FP8Helper.BWD_DTYPE) - assert_allclose(jnp.asarray(primitive_s_grad, np.float32), - jnp.asarray(ref_s_grad, np.float32), - dtype=FP8Helper.BWD_DTYPE) - - @pytest.mark.skipif(not is_fp8_supported, reason=reason) - @pytest.mark.parametrize('m,n,k', [(256, 256, 512), (16384, 1024, 2816), (16384, 2816, 1024), - (16384, 1024, 1024)]) - def test_grad_ln_gelu_fp8_mlp(self, m, n, k): - key = jax.random.PRNGKey(0) - subkeys = jax.random.split(key, 6) - activations = ('gelu',) - - a = jax.random.normal(subkeys[0], (m, k), jnp.bfloat16) - k1 = jax.random.normal(subkeys[1], (k, len(activations), n), jnp.bfloat16) - k2 = jax.random.normal(subkeys[2], (n, k), jnp.bfloat16) - b1 = jax.random.normal(subkeys[3], (len(activations), n), jnp.bfloat16) - b2 = jax.random.normal(subkeys[4], (k,), jnp.bfloat16) - s = jax.random.normal(subkeys[5], (k,), jnp.bfloat16) - - init_fp8_max = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2) - init_fp8_metas_amax = jnp.zeros( - (FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_LEN), jnp.float32) - init_fp8_metas_scale = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32) - init_fp8_metas_scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32) - - def primitive_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale, - fp8_metas_scale_inv): - # x is input tensor, matrix 2d - # y, z are weights, matrix 2d - # out = ((x * y) + w) * z + v - fp8_meta_pkg = FP8MetaPackage(2, fp8_max, fp8_metas_amax, fp8_metas_scale, - fp8_metas_scale_inv) - return jnp.mean( - layernorm_gelu_fp8_mlp(x, ln_s, None, [y, z], [w, v], fp8_meta_pkg, "rmsnorm")) - - def ln_gelu_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray, + def layernorm_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray, fp8_maxs: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray) -> jnp.ndarray: @@ -336,10 +245,20 @@ def ln_gelu_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.nda scale_inv[:FP8Helper.NUM_META_PER_GEMM]) linear_1_out = type_safe_dot_general(ln_out, kernel_1, fp8_gemm_1_pkg, ((1,), (0,))) - bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape - linear_1_out += jnp.reshape(bias_1, bias_1_shape) + if use_bias: + bias_1_shape = (1,) * (linear_1_out.ndim - bias_1.ndim) + bias_1.shape 
+ linear_1_out += jnp.reshape(bias_1, bias_1_shape) + + if 'linear' in activation_type: + x = jnp.split(linear_1_out, len(activation_type), axis=-2) + acts = [] + for idx, act_fn in enumerate(activation_type): + x_i = _convert_to_activation_function(act_fn)(x[idx]) + acts.append(x_i) + x = functools.reduce(operator.mul, acts) + else: + x = activation_dict[activation_type](linear_1_out) - x = jax.nn.gelu(linear_1_out) x = jnp.asarray(jnp.squeeze(x, axis=-2), jnp.bfloat16) fp8_gemm_2_pkg = FP8MetaPackage(1, fp8_maxs[FP8Helper.NUM_META_PER_GEMM:], @@ -348,15 +267,16 @@ def ln_gelu_fp8_mlp_ref(x: jnp.ndarray, ln_scale: jnp.ndarray, kernel_1: jnp.nda scale_inv[FP8Helper.NUM_META_PER_GEMM:]) output = type_safe_dot_general(x, kernel_2, fp8_gemm_2_pkg, ((1,), (0,))) - bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape - output += jnp.reshape(bias_2, bias_2_shape) + if use_bias: + bias_2_shape = (1,) * (output.ndim - bias_2.ndim) + bias_2.shape + output += jnp.reshape(bias_2, bias_2_shape) return output def ref_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv): return jnp.mean( - ln_gelu_fp8_mlp_ref(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale, + layernorm_fp8_mlp_ref(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale, fp8_metas_scale_inv)) value_n_grad_primitive_func = jit( @@ -373,12 +293,13 @@ def ref_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale, pri_fp8_metas_scale = init_fp8_metas_scale pri_fp8_metas_scale_inv = init_fp8_metas_scale_inv + # Convert str to index as str is not a valid type for JAX JIT for _ in range(3): ref_out, (ref_a_grad, ref_s_grad, ref_k1_grad, ref_k2_grad, ref_b1_grad, ref_b2_grad, ref_fp8_max, ref_fp8_metas_amax, ref_fp8_metas_scale, ref_fp8_metas_scale_inv) = value_n_grad_ref_func( a, s, k1, k2, b1, b2, ref_fp8_max, ref_fp8_metas_amax, - ref_fp8_metas_scale, ref_fp8_metas_scale_inv) + ref_fp8_metas_scale, ref_fp8_metas_scale_inv) for _ in range(3): primitive_out, (primitive_a_grad, primitive_s_grad, primitive_k1_grad, @@ -401,12 +322,14 @@ def ref_func(x, ln_s, y, z, w, v, fp8_max, fp8_metas_amax, fp8_metas_scale, assert_allclose(jnp.asarray(primitive_s_grad, np.float32), jnp.asarray(ref_s_grad, np.float32), dtype=FP8Helper.BWD_DTYPE) - assert_allclose(jnp.asarray(primitive_b1_grad, np.float32), - jnp.asarray(ref_b1_grad, np.float32), - dtype=jnp.bfloat16) - assert_allclose(jnp.asarray(primitive_b2_grad, np.float32), - jnp.asarray(ref_b2_grad, np.float32), - dtype=jnp.bfloat16) + if use_bias: + assert_allclose(jnp.asarray(primitive_b1_grad, np.float32), + jnp.asarray(ref_b1_grad, np.float32), + dtype=jnp.bfloat16) + assert_allclose(jnp.asarray(primitive_b2_grad, np.float32), + jnp.asarray(ref_b2_grad, np.float32), + dtype=jnp.bfloat16) + @pytest.fixture(name="random_inputs") diff --git a/transformer_engine/common/transpose/cast_transpose_fusion.cu b/transformer_engine/common/transpose/cast_transpose_fusion.cu index 434f2651d3..8e455dddb5 100644 --- a/transformer_engine/common/transpose/cast_transpose_fusion.cu +++ b/transformer_engine/common/transpose/cast_transpose_fusion.cu @@ -529,10 +529,11 @@ void cast_transpose_dbias(const Tensor &input, Tensor *dbias, Tensor *workspace, cudaStream_t stream) { - CheckInputTensor(input, "cast_transpose_dbias_input"); - CheckOutputTensor(*cast_output, "cast_output"); - CheckOutputTensor(*transposed_output, "transposed_output"); - CheckOutputTensor(*dbias, "dbias"); + // TODO + // CheckInputTensor(input, 
"cast_transpose_dbias_input"); + // CheckOutputTensor(*cast_output, "cast_output"); + // CheckOutputTensor(*transposed_output, "transposed_output"); + // CheckOutputTensor(*dbias, "dbias"); NVTE_CHECK(input.data.shape.size() == 2, "Input must have 2 dimensions."); NVTE_CHECK(cast_output->data.shape.size() == 2, "C output must have 2 dimensions."); diff --git a/transformer_engine/jax/cpp_extensions.py b/transformer_engine/jax/cpp_extensions.py index 3356aafef5..adcd5770e2 100644 --- a/transformer_engine/jax/cpp_extensions.py +++ b/transformer_engine/jax/cpp_extensions.py @@ -4334,6 +4334,231 @@ def dgelu_dbias_cast_transpose( transpose_axis_boundary=transpose_axis_boundary) +class DBiasCastTransposePrimitive(BasePrimitive): + """ + DBias Cast Transpose Primitive + """ + name = "te_dbias_cast_transpose" + multiple_results = True + # out_dtype, static_axis_boundary, transpose_axis_boundary + impl_static_args = (4, 5, 6) + inner_primitive = None + outer_primitive = None + + @staticmethod + def abstract(dz_aval, amax_aval, scale_aval, scale_inv_aval, *, out_dtype, + static_axis_boundary, transpose_axis_boundary): + """ + te_dbias_cast_transpose_p abstract + """ + dtype = dtypes.canonicalize_dtype(dz_aval.dtype) + assert dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 + gi_hidden_size = dz_aval.shape[-1] + t_shape = _multidim_transpose(dz_aval.shape, static_axis_boundary, transpose_axis_boundary) + out = dz_aval.update(shape=dz_aval.shape, dtype=out_dtype) + t_out = dz_aval.update(shape=t_shape, dtype=out_dtype) + + if dz_aval.shape[-2] == 2: + gi_hidden_size *= 2 + dbias_shape = (*dz_aval.shape[:static_axis_boundary + 1], gi_hidden_size) + dbias = dz_aval.update(shape=dbias_shape, dtype=dtype) + + updated_amax_aval = amax_aval.update(shape=amax_aval.shape, dtype=amax_aval.dtype) + wkspace_info, = transformer_engine_jax.get_dbias_ct_workspace_sizes( + dz_aval.size // gi_hidden_size, + gi_hidden_size, + jax_dtype_to_te_dtype(dz_aval.dtype), + jax_dtype_to_te_dtype(out_dtype) + ) + wkspace_aval = dz_aval.update(shape=wkspace_info[0], + dtype=te_dtype_to_jax_dtype(wkspace_info[1])) + + return out, t_out, dbias, updated_amax_aval, wkspace_aval + + @staticmethod + def outer_abstract(*args, **kwargs): + """ + te_dbias_cast_transpose_p outer abstract + """ + + out, t_out, dbias, updated_amax_aval, _ = \ + DBiasCastTransposePrimitive.abstract(*args, **kwargs) + return out, t_out, dbias, updated_amax_aval + + @staticmethod + def lowering(ctx, dz, amax, scale, scale_inv, *, out_dtype, static_axis_boundary, + transpose_axis_boundary): + """ + te_dbias_cast_transpose_p lowering rules + """ + dz_aval, amax_aval, scale_aval, scale_inv_aval = ctx.avals_in + assert dz_aval.dtype in [jnp.float32, jnp.float16, jnp.bfloat16] + assert amax_aval.dtype == jnp.float32 + assert scale_aval.dtype == jnp.float32 + assert scale_inv_aval.dtype == jnp.float32 + ir_dz_type = ir.RankedTensorType(dz.type) + ir_dz_shape = ir_dz_type.shape + ir_hidden_szie = ir_dz_shape[-1] + if dz_aval.shape[-2] == 2: + batch_szie = reduce(operator.mul, ir_dz_shape[:-2]) + ir_hidden_szie *= 2 + else: + batch_szie = reduce(operator.mul, ir_dz_shape[:-1]) + contracted_dz_shape = (batch_szie, ir_hidden_szie) + ir_out_dtype = jax_dtype_to_ir_dtype(out_dtype) + ir_amax_type = ir.RankedTensorType(amax.type) + ir_amax_dtype = ir_amax_type.element_type + ir_amax_shape = ir_amax_type.shape + ir_scale_shape = ir_amax_shape + 
ir_scale_inv_shape = ir_amax_shape + transposed_dz_shape = _multidim_transpose(ir_dz_shape, static_axis_boundary, + transpose_axis_boundary) + dbias_shape = (*ir_dz_shape[:static_axis_boundary + 1], ir_hidden_szie) + + wkspace_aval = ctx.avals_out[-1] + + out_types = [ + ir.RankedTensorType.get(ir_dz_shape, ir_out_dtype), + ir.RankedTensorType.get(transposed_dz_shape, ir_out_dtype), + ir.RankedTensorType.get(dbias_shape, ir_dz_type.element_type), + ir.RankedTensorType.get(ir_amax_shape, ir_amax_dtype), + ir.RankedTensorType.get(wkspace_aval.shape, jax_dtype_to_ir_dtype(wkspace_aval.dtype)), + ] + operands = [dz, amax, scale, scale_inv] + operand_shapes = [ir_dz_shape, ir_amax_shape, ir_scale_shape, ir_scale_inv_shape] + args = CustomCallArgsWrapper(out_types, operands, operand_shapes) + opaque = transformer_engine_jax.pack_common_wk_descriptor( + contracted_dz_shape, wkspace_aval.shape, jax_dtype_to_te_dtype(dz_aval.dtype), + jax_dtype_to_te_dtype(out_dtype), jax_dtype_to_te_dtype(wkspace_aval.dtype)) + + out = custom_caller(DBiasCastTransposePrimitive.name, + args, + opaque, + False, + operand_output_aliases={1: 3}) + + return out + + @staticmethod + def impl(dz, amax, scale, scale_inv, out_dtype, static_axis_boundary, + transpose_axis_boundary): + """ + to describe implementation + """ + assert DBiasCastTransposePrimitive.inner_primitive is not None + out, t_out, dbias, updated_amax, _ = DBiasCastTransposePrimitive.inner_primitive.bind( + dz, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary) + return out, t_out, dbias, updated_amax + + @staticmethod + def batcher(batched_args, batch_dims, *, out_dtype, static_axis_boundary, + transpose_axis_boundary): + """ + to describe batch rules for vmap + """ + del static_axis_boundary + _check_valid_batch_dims(batch_dims) + assert DBiasCastTransposePrimitive.outer_primitive is not None + dz, amax, scale, scale_inv = batched_args + dz_bdim, _, amax_bdim, _, _ = batch_dims + + # Minus batch dim. 
+ transpose_axis_boundary = _normalize_axis_boundary(transpose_axis_boundary, dz.ndim - 1) + transpose_axis_boundary += 1 # Plus batch dim + + out_bdims = dz_bdim, dz_bdim, dz_bdim, amax_bdim + return DBiasCastTransposePrimitive.outer_primitive.bind( + dz, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=dz_bdim, + transpose_axis_boundary=transpose_axis_boundary), out_bdims + + @staticmethod + def infer_sharding_from_operands(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, + arg_infos, result_infos): + del out_dtype, result_infos + x_spec = get_padded_spec(arg_infos[1]) + out_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) + tranposed_out_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) + dbias_shaprding = NamedSharding( + mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1])) + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) + return (out_sharding, tranposed_out_sharding, dbias_shaprding, amax_sharding) + + @staticmethod + def partition(out_dtype, static_axis_boundary, transpose_axis_boundary, mesh, arg_infos, + result_infos): + del result_infos + x_spec = get_padded_spec(arg_infos[1]) + casted_x_sharding = NamedSharding(mesh, PartitionSpec(*x_spec)) + xt_spec = _multidim_transpose(x_spec, static_axis_boundary, transpose_axis_boundary) + casted_transposed_x_sharding = NamedSharding(mesh, PartitionSpec(*xt_spec)) + + dbias_shaprding = NamedSharding( + mesh, PartitionSpec(*x_spec[:static_axis_boundary + 1], x_spec[-1])) + + amax_sharding = NamedSharding(mesh, PartitionSpec(*get_padded_spec(arg_infos[2]))) + arg_shardings = tuple(arg_i.sharding for arg_i in arg_infos) + out_shardings = (casted_x_sharding, casted_transposed_x_sharding, dbias_shaprding, + amax_sharding) + + def sharded_impl(dz, amax, scale, scale_inv): + local_out, local_t_out, local_dbias, local_amax = DBiasCastTransposePrimitive.impl( + dz, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary) + global_dbias = all_reduce_sum_along_dp_fsdp(local_dbias) + global_updated_amax = all_reduce_max_along_all_axes_except_PP(local_amax) + return local_out, local_t_out, global_dbias, global_updated_amax + + return mesh, sharded_impl, out_shardings, arg_shardings + + +register_primitive(DBiasCastTransposePrimitive) + + +def dbias_cast_transpose( + dz: jnp.ndarray, + amax: jnp.ndarray, + scale: jnp.ndarray, + scale_inv: jnp.ndarray, + out_dtype: TEDType, + static_axis_boundary: int, + transpose_axis_boundary: int = -1) -> Tuple[jnp.ndarray, jnp.ndarray, jnp.ndarray]: + """ + cast transpose dbias partial fusion wrapper + Return FP8(inputs), dbias + """ + if static_axis_boundary < 0: + static_axis_boundary = -1 # means no static axes + + return DBiasCastTransposePrimitive.outer_primitive.bind( + dz, + amax, + scale, + scale_inv, + out_dtype=out_dtype, + static_axis_boundary=static_axis_boundary, + transpose_axis_boundary=transpose_axis_boundary) + + class GatedGeluFp8Primitive(BasePrimitive): """ Gated Gelu FP8 Primitive diff --git a/transformer_engine/jax/csrc/extensions.cpp b/transformer_engine/jax/csrc/extensions.cpp index 5e4ab4f205..8aa6b492c8 100644 --- a/transformer_engine/jax/csrc/extensions.cpp +++ b/transformer_engine/jax/csrc/extensions.cpp @@ -29,6 +29,7 @@ pybind11::dict Registrations() { dict["te_gelu_fp8"] = EncapsulateFunction(GeluFP8); 
dict["te_dgelu"] = EncapsulateFunction(DGelu); dict["te_dgelu_dbias_cast_transpose"] = EncapsulateFunction(DGeluDBiasCastTranspose); + dict["te_dbias_cast_transpose"] = EncapsulateFunction(DBiasCastTranspose); dict["te_gated_gelu"] = EncapsulateFunction(GatedGelu); dict["te_gated_gelu_fp8"] = EncapsulateFunction(GatedGeluFP8); dict["te_dgated_gelu"] = EncapsulateFunction(DGatedGelu); @@ -66,6 +67,7 @@ PYBIND11_MODULE(transformer_engine_jax, m) { m.def("get_device_compute_capability", &GetDeviceComputeCapability); m.def("get_cublasLt_version", &cublasLtGetVersion); m.def("get_dgelu_dbias_ct_workspace_sizes", &GetDGeluDBiasCastTransposeWorkspaceSizes); + m.def("get_dbias_ct_workspace_sizes", &GetDBiasCastTransposeWorkspaceSizes); m.def("get_layernorm_fwd_workspace_sizes", &GetLayerNormForwardWorkspaceSizes); m.def("get_layernorm_bwd_workspace_sizes", &GetLayerNormBackwardWorkspaceSizes); m.def("get_fused_attn_fwd_workspace_sizes", &GetFusedAttnForwardWorkspaceSizes); diff --git a/transformer_engine/jax/csrc/modules.cpp b/transformer_engine/jax/csrc/modules.cpp index 4ac6fa58b1..48b02bcaeb 100644 --- a/transformer_engine/jax/csrc/modules.cpp +++ b/transformer_engine/jax/csrc/modules.cpp @@ -301,6 +301,69 @@ void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *op dbias_tensor.data(), workspace.data(), stream); } +// HERE +pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, + DType in_dtype, DType out_dtype) { + auto input_shape = std::vector{batch_size, hidden_size}; + auto output_shape = std::vector{batch_size, hidden_size}; + auto output_trans_shape = std::vector{hidden_size, batch_size}; + auto dbias_shape = std::vector{hidden_size}; + + auto input_tensor = TensorWrapper(nullptr, input_shape, in_dtype); + auto output_tensor = TensorWrapper(nullptr, output_shape, out_dtype); + auto output_trans_tensor = TensorWrapper(nullptr, output_trans_shape, out_dtype); + auto dbias_tensor = TensorWrapper(nullptr, dbias_shape, in_dtype); + + TensorWrapper dummy_workspace; + + nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), + output_trans_tensor.data(), dbias_tensor.data(), + dummy_workspace.data(), nullptr); + + auto work_shape = MakeShapeVector(dummy_workspace.shape()); + return pybind11::make_tuple(std::make_pair(work_shape, dummy_workspace.dtype())); +} + +void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, + size_t opaque_len) { + auto *input = buffers[0]; + float *amax = reinterpret_cast(buffers[1]); + float *scale = reinterpret_cast(buffers[2]); + float *scale_inv = reinterpret_cast(buffers[3]); + auto *output = buffers[4]; + auto *output_trans = buffers[5]; + auto *dbias = buffers[6]; + float *amax_out = reinterpret_cast(buffers[7]); + void *workspace_ptr = buffers[8]; + + const auto &desc = *UnpackOpaque(opaque, opaque_len); + assert(amax == amax_out); + if (!use_fp8(desc.out_dtype)) { + scale = nullptr; + scale_inv = nullptr; + amax_out = nullptr; + } + auto m = desc.shape.dims[0]; + auto n = desc.shape.dims[1]; + auto input_shape = std::vector{m, n}; + auto output_shape = std::vector{m, n}; + auto output_trans_shape = std::vector{n, m}; + auto dbias_shape = std::vector{n}; + + auto input_tensor = TensorWrapper(input, input_shape, desc.in_dtype); + auto output_tensor = + TensorWrapper(output, output_shape, desc.out_dtype, amax_out, scale, scale_inv); + auto output_trans_tensor = + TensorWrapper(output_trans, output_trans_shape, desc.out_dtype, amax_out, scale, scale_inv); + auto 
dbias_tensor = TensorWrapper(dbias, dbias_shape, desc.in_dtype); + + auto workspace = TensorWrapper(workspace_ptr, desc.wkshape.to_vector(), desc.wk_dtype); + + nvte_cast_transpose_dbias(input_tensor.data(), output_tensor.data(), + output_trans_tensor.data(), dbias_tensor.data(), + workspace.data(), stream); +} + void GatedGeluImpl(void *input, size_t m, size_t n, DType in_dtype, DType out_dtype, float *scale, cudaStream_t stream, float *scale_inverse, float *amax, void *output) { auto input_shape = std::vector{m, n * 2}; diff --git a/transformer_engine/jax/csrc/modules.h b/transformer_engine/jax/csrc/modules.h index 04f0039b02..4285c8228e 100644 --- a/transformer_engine/jax/csrc/modules.h +++ b/transformer_engine/jax/csrc/modules.h @@ -152,6 +152,12 @@ pybind11::tuple GetDGeluDBiasCastTransposeWorkspaceSizes(size_t batch_size, size void DGeluDBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); +pybind11::tuple GetDBiasCastTransposeWorkspaceSizes(size_t batch_size, size_t hidden_size, + DType in_dtype, DType out_dtype); + +void DBiasCastTranspose(cudaStream_t stream, void **buffers, const char *opaque, + size_t opaque_len); + void GatedGelu(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); void GatedGeluFP8(cudaStream_t stream, void **buffers, const char *opaque, size_t opaque_len); diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 8ca8edcb0b..36008cf854 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -22,8 +22,7 @@ from ..fp8 import FP8Helper, FP8MetaPackage from ..layernorm import canonicalize_layernorm_type from ..layernorm import layernorm, layernorm_fp8_dot -from ..mlp import layernorm_geglu_fp8_mlp, geglu -from ..mlp import layernorm_gelu_fp8_mlp, gelu +from ..mlp import fused_layernorm_fp8_mlp, activation_lu from ..softmax import is_softmax_kernel_available from ..softmax import softmax, SoftmaxType from ..sharding import with_sharding_constraint_by_logical_axes @@ -944,35 +943,22 @@ def __call__(self, inputs: Array, deterministic: bool = False) -> Array: fuse_layernorm = FP8Helper.is_fp8_enabled( ) and not self.return_layernorm_output and self.enable_layernorm - def is_geglu(acts): - geglu_act_pool = [('gelu', 'linear'), ('linear', 'gelu')] - - normalize_acts = [] - for act in acts: - if not isinstance(act, str): - return False - normalize_acts.append(act.lower()) - return tuple(normalize_acts) in geglu_act_pool - - def is_gelu(acts): - geglu_act_pool = [('gelu',)] - - normalize_acts = [] - for act in acts: - if not isinstance(act, str): - return False - normalize_acts.append(act.lower()) - return tuple(normalize_acts) in geglu_act_pool - - use_fused_ln_geglu_mlp = fuse_layernorm \ - and (not self.use_bias) and is_geglu(self.activations) \ - and (self.intermediate_dropout_rate < 1e-3) \ - and not self.enable_low_rank_adaptation - - use_fused_ln_gelu_mlp = fuse_layernorm \ - and self.use_bias and is_gelu(self.activations) \ - and (self.intermediate_dropout_rate < 1e-3) \ - and not self.enable_low_rank_adaptation + # Make sure each tuple is sorted in alphabet order + gated_act_pool = [('gelu', 'linear')] + #('linear', 'silu')] coming + act_pool = [('gelu',)] + #('silu',)] coming + normalize_acts = [] + for act in self.activations: + if not isinstance(act, str): + return False + normalize_acts.append(act.lower()) + normalize_acts = tuple(sorted(normalize_acts)) + is_gated = normalize_acts in gated_act_pool + 
is_act_implemented = normalize_acts in (gated_act_pool + act_pool) + + use_fused_layernorm_mlp = fuse_layernorm and is_act_implemented and\ + self.intermediate_dropout_rate < 1e-3 # LayerNorm if self.enable_layernorm: @@ -1045,38 +1031,26 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): ffn1_ckpt_name = 'ffn1' ffn2_ckpt_name = 'ffn2' - if use_fused_ln_geglu_mlp: - assert self.axis == -1 # Only support axis = =-1 at this moment - - out = layernorm_geglu_fp8_mlp(y, - scale, - ln_bias, [kernel_1, kernel_2], - fp8_meta_package, - self.layernorm_type, - zero_centered_gamma=self.zero_centered_gamma, - epsilon=self.epsilon, - layernorm_input_axes=self.layernorm_input_axes, - dot_1_input_axes=self.dot_1_input_axes, - dot_2_input_axes=self.dot_2_input_axes, - ffn1_ckpt_name=ffn1_ckpt_name, - ffn2_ckpt_name=ffn2_ckpt_name) - elif use_fused_ln_gelu_mlp: + if use_fused_layernorm_mlp: assert self.axis == -1 # Only support axis = =-1 at this moment + bias_1_shape = intermediate_dim if self.use_bias else 0 bias_1 = nn_partitioning.param_with_axes('wi_bias', self.bias_init, - intermediate_dim, + bias_1_shape, jnp.float32, axes=self.bias_axes_1) bias_1 = bias_1.astype(self.dtype) + bias_2_shape = (hidden_size,) if self.use_bias else (0,) bias_2 = nn_partitioning.param_with_axes('wo_bias', - self.bias_init, (hidden_size,), + self.bias_init, + bias_2_shape, jnp.float32, axes=self.bias_axes_2) bias_2 = bias_2.astype(self.dtype) - out = layernorm_gelu_fp8_mlp(y, + out = fused_layernorm_fp8_mlp(y, scale, ln_bias, [kernel_1, kernel_2], [bias_1, bias_2], fp8_meta_package, @@ -1087,9 +1061,10 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): dot_1_input_axes=self.dot_1_input_axes, dot_2_input_axes=self.dot_2_input_axes, ffn1_ckpt_name=ffn1_ckpt_name, - ffn2_ckpt_name=ffn2_ckpt_name) + ffn2_ckpt_name=ffn2_ckpt_name, + activation_type = normalize_acts, + use_bias = self.use_bias) else: # not use_fused_ln_geglu_mlp - # DenseGeneral 1 gemm1_fp8_meta_package = None if fp8_meta_package is None \ else fp8_meta_package.get_package_by_gemm_idx(0) @@ -1142,31 +1117,29 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): x += _apply_low_rank_adaptation(y, axis, intermediate_dim, wi_lora_a_kernel, wi_lora_b_kernel, self.low_rank_adaptation_alpha) - bias = None + bias_1 = None if self.use_bias: - bias = nn_partitioning.param_with_axes('wi_bias', + bias_1 = nn_partitioning.param_with_axes('wi_bias', self.bias_init, intermediate_dim, jnp.float32, axes=self.bias_axes_1) - bias = bias.astype(self.dtype) - bias_shape = (1,) * (x.ndim - bias.ndim) + bias.shape - x += jnp.reshape(bias, bias_shape) + bias_1 = bias_1.astype(self.dtype) + bias_1_shape = (1,) * (x.ndim - bias_1.ndim) + bias_1.shape + x += jnp.reshape(bias_1, bias_1_shape) x = checkpoint_name(x, ffn1_ckpt_name) activations = [] - if is_geglu(self.activations): - z = geglu(x) - elif is_gelu(self.activations): - z = gelu(x) - z = jnp.reshape(z, (*z.shape[:-2], -1)) + if is_act_implemented: + z = activation_lu(x, normalize_acts) else: x = jnp.split(x, num_activations, axis=-2) for idx, act_fn in enumerate(self.activations): x_i = _convert_to_activation_function(act_fn)(x[idx]) activations.append(x_i) z = functools.reduce(operator.mul, activations) + if not is_gated: z = jnp.reshape(z, (*z.shape[:-2], -1)) z = nn.Dropout(rate=self.intermediate_dropout_rate, @@ -1207,14 +1180,14 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): out += _apply_low_rank_adaptation(z, axis, hidden_size_tuple, wo_lora_a_kernel, wo_lora_b_kernel, 
self.low_rank_adaptation_alpha) - bias = None + bias_2 = None if self.use_bias: - bias = nn_partitioning.param_with_axes('wo_bias', + bias_2 = nn_partitioning.param_with_axes('wo_bias', self.bias_init, (hidden_size,), jnp.float32, axes=self.bias_axes_2) - bias = bias.astype(self.dtype) - out += jnp.reshape(bias, (1,) * (out.ndim - 1) + (-1,)) + bias_2 = bias_2.astype(self.dtype) + out += jnp.reshape(bias_2, (1,) * (out.ndim - 1) + (-1,)) out = checkpoint_name(out, ffn2_ckpt_name) diff --git a/transformer_engine/jax/mlp.py b/transformer_engine/jax/mlp.py index 3b531a6150..30f6d8456b 100644 --- a/transformer_engine/jax/mlp.py +++ b/transformer_engine/jax/mlp.py @@ -3,15 +3,15 @@ # See LICENSE for license information. """JAX MLP modules""" -from typing import List, Tuple +from typing import List, Tuple, Sequence, Union, Callable from functools import partial import jax import jax.numpy as jnp from jax.ad_checkpoint import checkpoint_name -from .cpp_extensions import cast_fp8, transpose, cast_transpose -from .cpp_extensions import gelu as te_gelu +from .cpp_extensions import cast_fp8, transpose, cast_transpose, dbias_cast_transpose +from .cpp_extensions import gelu from .cpp_extensions import gelu_fp8, dgelu, dgelu_dbias_cast_transpose from .cpp_extensions import gated_gelu, gated_gelu_fp8 from .cpp_extensions import dgated_gelu, dgated_gelu_cast_transpose @@ -23,369 +23,56 @@ from .sharding import with_sharding_constraint_by_logical_axes -def gelu(x: jnp.ndarray): - """ - Gelu - """ - output = _gelu(x) - return output - - -@partial(jax.custom_vjp) -def _gelu(x: jnp.ndarray): - - geglu_output, _ = _gelu_fwd_rule(x) - - return geglu_output - - -def _gelu_fwd_rule(x): - geglu_output = te_gelu(x) - return geglu_output, (x,) - - -def _gelu_bwd_rule(ctx, g): - x, = ctx - assert x.dtype == g.dtype - - dx = dgelu(g, x) - dx = jnp.reshape(dx, x.shape) - return (dx,) +activation_dict = { + ('gelu',): {'fwd': gelu, + "bwd": dgelu}, + ('gelu', 'linear'): {'fwd': gated_gelu, + 'bwd': dgated_gelu} +} +activation_fp8_dict = { + ('gelu',): {'fwd': gelu_fp8, + 'bwd': dgelu_dbias_cast_transpose}, + ('gelu', 'linear'): {'fwd': gated_gelu_fp8, + 'bwd': dgated_gelu_cast_transpose} +} -_gelu.defvjp(_gelu_fwd_rule, _gelu_bwd_rule) - -def geglu(x: jnp.ndarray): +def activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]): """ - Gated gelu + Activation Unit """ - assert x.shape[-2] == 2 # Linear + GeLU - - output = _geglu(x) - + if len(activation_type) > 1: + assert x.shape[-2] == 2 # Linear + GeLU + output = _activation_lu(x, activation_type) return output -@partial(jax.custom_vjp) -def _geglu(x: jnp.ndarray): +@partial(jax.custom_vjp, nondiff_argnums=(1,)) +def _activation_lu(x: jnp.ndarray, activation_type: Sequence[Union[str, Callable]]): - geglu_output, _ = _geglu_fwd_rule(x) + _output, _ = _activation_lu_fwd_rule(x, activation_type) - return geglu_output + return _output -def _geglu_fwd_rule(x): - geglu_output = gated_gelu(x) - return geglu_output, (x,) +def _activation_lu_fwd_rule(x, activation_type): + fwd_output = activation_dict[activation_type]["fwd"](x) + return fwd_output, (x,) -def _geglu_bwd_rule(ctx, g): +def _activation_lu_bwd_rule(activation_type, ctx, g): x, = ctx assert x.dtype == g.dtype - dx = dgated_gelu(g, x) + dx = activation_dict[activation_type]["bwd"](g, x) dx = jnp.reshape(dx, x.shape) return (dx,) +_activation_lu.defvjp(_activation_lu_fwd_rule, _activation_lu_bwd_rule) -_geglu.defvjp(_geglu_fwd_rule, _geglu_bwd_rule) - - -def layernorm_geglu_fp8_mlp(x: jnp.ndarray, 
- gamma: jnp.ndarray, - beta: jnp.ndarray, - kernels: List[jnp.ndarray], - fp8_gemm_pkg: FP8MetaPackage, - layernorm_type: str, - zero_centered_gamma: bool = False, - epsilon: float = 1e-6, - layernorm_input_axes: Tuple[str, ...] = None, - dot_1_input_axes: Tuple[str, ...] = None, - dot_2_input_axes: Tuple[str, ...] = None, - ffn1_ckpt_name: str = 'ffn1', - ffn2_ckpt_name: str = 'ffn2') -> jnp.ndarray: - """ - Layernorm + GEMM1 + GeGLU + GEMM2 - """ - - assert len(kernels) == 2 - assert fp8_gemm_pkg.num_of_gemm == len(kernels) - - kernel_1 = kernels[0] - kernel_2 = kernels[1] - fp8_max = fp8_gemm_pkg.fp8_max - amax = fp8_gemm_pkg.amax - scale = fp8_gemm_pkg.scale - scale_inv = fp8_gemm_pkg.scale_inv - - fwd_dtype = FP8Helper.FWD_DTYPE - bwd_dtype = FP8Helper.BWD_DTYPE - - layernorm_type = canonicalize_layernorm_type(layernorm_type) - if layernorm_type == 'rmsnorm': - assert beta is None, "beta should be None if layernorm_type is 'rmsnorm'" - assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ - "if layernorm_type is 'rmsnorm'" - - output = _layernorm_geglu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, scale, - scale_inv, fwd_dtype, bwd_dtype, layernorm_type, - zero_centered_gamma, epsilon, layernorm_input_axes, - dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name, - ffn2_ckpt_name) - return output - - -@partial(jax.custom_vjp, nondiff_argnums=(9, 10, 11, 12, 13, 14, 15, 16, 17, 18)) -def _layernorm_geglu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, - kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, fp8_max: jnp.ndarray, - amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, - fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, layernorm_type: str, - zero_centered_gamma: bool, epsilon: float, - layernorm_input_axes: Tuple[str, ...], - dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...], - ffn1_ckpt_name: str, ffn2_ckpt_name: str): - output, _ = _layernorm_geglu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, fp8_max, amax, - scale, scale_inv, fwd_dtype, bwd_dtype, - layernorm_type, zero_centered_gamma, epsilon, - layernorm_input_axes, dot_1_input_axes, - dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name) - return output - - -def _layernorm_geglu_fp8_mlp_fwd_rule( - x, - gamma, - beta, - kernel_1, - kernel_2, - fp8_max, - amax, - scale, - scale_inv, - fwd_dtype, - bwd_dtype, # pylint: disable=unused-argument - layernorm_type, - zero_centered_gamma, - epsilon, - layernorm_input_axes, - dot_1_input_axes, - dot_2_input_axes, - ffn1_ckpt_name, - ffn2_ckpt_name): - - # x should be in shape of (batch..., hidden) - # Kernel_1 should be in shape of (Hidden_in, 2, Hidden_out) - # Kernel_2 should be in shape of (Hidden_in, Hidden_out) - assert len(kernel_1.shape) == 3 - assert kernel_1.shape[-2] == 2 - assert len(kernel_2.shape) == 2 - - x_contracting_dims = (len(x.shape) - 1,) - xt_batch_dims = tuple(range(1, x.ndim)) - assert x.shape[x_contracting_dims[0]] == kernel_1.shape[0] - assert kernel_1.shape[-1] == kernel_2.shape[0] - - amax = FP8Helper.update_amax_history(amax) - - gemm1_x_idx, gemm1_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(0) - - x_amax = amax[gemm1_x_idx, 0:1] - x_scale = scale[gemm1_x_idx] - x_scale_inv = scale_inv[gemm1_x_idx] - - x = with_sharding_constraint_by_logical_axes(x, layernorm_input_axes) - - if layernorm_type == 'layernorm': - ln_out, mu, rsigma, updated_x_amax = layernorm_fwd_fp8( - x, - gamma, - beta, - x_amax, - x_scale, - x_scale_inv, - out_dtype=fwd_dtype, - 
zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) - else: - assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ - "if layernorm_type is 'rmsnorm'" - ln_out, rsigma, updated_x_amax = rmsnorm_fwd_fp8(x, - gamma, - x_amax, - x_scale, - x_scale_inv, - out_dtype=fwd_dtype, - epsilon=epsilon) - mu = None - - assert x.shape == ln_out.shape - - kernel_1_amax = amax[gemm1_kernel_idx, 0:1] - kernel_1_scale = scale[gemm1_kernel_idx] - kernel_1_scale_inv = scale_inv[gemm1_kernel_idx] - - # Note (Ming Huang): Use cast only to allow XLA handle tranpose for avoiding - # unnecessary copy to break FP8 GEMM pattern matching. - casted_kernel_1, updated_kernel_1_amax = \ - cast_fp8(kernel_1, kernel_1_amax, kernel_1_scale, kernel_1_scale_inv, fwd_dtype) - - ln_out = with_sharding_constraint_by_logical_axes(ln_out, dot_1_input_axes) - - # (batch..., hidden_in) x (hidden_in, 2, hidden_out) - dot_1_output = fp8_dot_impl(ln_out, casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype, - (x_contracting_dims, (0,)), - get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) - dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) - - gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1) - - geglu_out_amax = amax[gemm2_x_idx, 0:1] - geglu_out_scale = scale[gemm2_x_idx] - geglu_out_scale_inv = scale_inv[gemm2_x_idx] - - # (batch..., hidden_in) -> (batch..., hidden) - casted_geglu_out, updated_geglu_amax = gated_gelu_fp8(dot_1_output, geglu_out_amax, - geglu_out_scale, geglu_out_scale_inv, - fwd_dtype) - - casted_geglu_out = with_sharding_constraint_by_logical_axes(casted_geglu_out, dot_2_input_axes) - - kernel_2_scale = scale[gemm2_kernel_idx] - kernel_2_scale_inv = scale_inv[gemm2_kernel_idx] - # Note (Ming Huang): Use native cast to allow XLA handle tranpose for avoiding - # unnecessary copy to break FP8 GEMM pattern matching. 
- casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale) - - # (batch..., hidden_in) x (hidden_out, hidden_in) - dot_2_output = fp8_dot_impl(casted_geglu_out, casted_kernel_2, geglu_out_scale_inv, - kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)), - get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) - dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name) - - ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, casted_kernel_1, - casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_geglu_amax, - updated_kernel_1_amax, updated_kernel_2_amax, x_contracting_dims, xt_batch_dims) - - return dot_2_output, ctx - - -def _layernorm_geglu_fp8_mlp_bwd_rule( - fwd_dtype, # pylint: disable=unused-argument - bwd_dtype, - layernorm_type, - zero_centered_gamma, - epsilon, - layernorm_input_axes, - dot_1_input_axes, - dot_2_input_axes, - ffn1_ckpt_name, # pylint: disable=unused-argument - ffn2_ckpt_name, # pylint: disable=unused-argument - ctx, - grad): - x, ln_out, mu, rsigma, gamma, dot_1_output, casted_geglu_out, \ - casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, \ - updated_geglu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \ - x_contracting_dims, xt_batch_dims = ctx - - gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1) - - grad_amax = amax[gemm2_grad_idx, 0:1] - grad_scale = scale[gemm2_grad_idx] - grad_scale_inv = scale_inv[gemm2_grad_idx] - - # Since the sharding of outputs should be the same as dot_1's input - grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) - - casted_grad, casted_grad_t, updated_grad_amax = \ - cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype, - static_axis_boundary=-1, transpose_axis_boundary=-1) - - casted_geglu_out_t = transpose(casted_geglu_out, - static_axis_boundary=-1, - transpose_axis_boundary=-1) - - # (hidden, batch...,) x (hidden, batch...) - gemm2_x_scale_inv = scale_inv[gemm2_x_idx] - wgrad_2 = fp8_dot_impl(casted_geglu_out_t, casted_grad_t, gemm2_x_scale_inv, grad_scale_inv, - grad.dtype, (xt_batch_dims, xt_batch_dims), - get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD)) - - # (batch..., hidden_out) x (hidden_in, hidden_out) - kernel_2_scale_inv = scale_inv[gemm2_kernel_idx] - dgrad_2 = fp8_dot_impl(casted_grad, casted_kernel_2, grad_scale_inv, kernel_2_scale_inv, - grad.dtype, (x_contracting_dims, (1,)), - get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD)) - - dgrad_2 = with_sharding_constraint_by_logical_axes(dgrad_2, dot_2_input_axes) - - gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0) - - dgeglu_amax = amax[gemm1_grad_idx, 0:1] - dgeglu_scale = scale[gemm1_grad_idx] - dgeglu_scale_inv = scale_inv[gemm1_grad_idx] - - casted_dgeglu, casted_dgeglu_t, updated_dgeglu_amax = dgated_gelu_cast_transpose( - dgrad_2, - dot_1_output, - dgeglu_amax, - dgeglu_scale, - dgeglu_scale_inv, - bwd_dtype, - static_axis_boundary=-1) - - ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1) - - # (hidden, batch...) x (2, hidden, batch...) 
- xt_batch_dims_plus_act_dim = tuple(i + 1 for i in xt_batch_dims) - gemm1_x_scale_inv = scale_inv[gemm1_x_idx] - wgrad_1 = fp8_dot_impl(ln_out_t, casted_dgeglu_t, gemm1_x_scale_inv, dgeglu_scale_inv, - grad.dtype, (xt_batch_dims, xt_batch_dims_plus_act_dim), - get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD)) - - # (batch..., 2, hidden_out) x (hidden_in, 2, hidden_out) - x_contracting_dims_plus_act_dim = (min(x_contracting_dims),) + tuple( - i + 1 for i in x_contracting_dims) - kernel_1_scale_inv = scale_inv[gemm1_kernel_idx] - dgrad_1 = fp8_dot_impl(casted_dgeglu, casted_kernel_1, dgeglu_scale_inv, kernel_1_scale_inv, - grad.dtype, (x_contracting_dims_plus_act_dim, (1, 2)), - get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD)) - - dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes) - - if layernorm_type == 'layernorm': - dx, dgamma, dbeta = layernorm_bwd(dgrad_1, - x, - mu, - rsigma, - gamma, - zero_centered_gamma=zero_centered_gamma, - epsilon=epsilon) - else: - assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ - "if layernorm_type is 'rmsnorm'" - dx, dgamma = rmsnorm_bwd(dgrad_1, x, rsigma, gamma, epsilon=epsilon) - dbeta = None - - amax = amax.at[gemm1_x_idx, 0].set(updated_x_amax[0]) - amax = amax.at[gemm1_kernel_idx, 0].set(updated_kernel_1_amax[0]) - amax = amax.at[gemm1_grad_idx, 0].set(updated_dgeglu_amax[0]) - amax = amax.at[gemm2_x_idx, 0].set(updated_geglu_amax[0]) - amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax) - amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0]) - - scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale) - - return dx, dgamma, dbeta, wgrad_1, wgrad_2, \ - fp8_max, amax, scale, scale_inv - - -_layernorm_geglu_fp8_mlp.defvjp(_layernorm_geglu_fp8_mlp_fwd_rule, - _layernorm_geglu_fp8_mlp_bwd_rule) - - -def layernorm_gelu_fp8_mlp(x: jnp.ndarray, +def fused_layernorm_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, kernels: List[jnp.ndarray], @@ -398,9 +85,11 @@ def layernorm_gelu_fp8_mlp(x: jnp.ndarray, dot_1_input_axes: Tuple[str, ...] = None, dot_2_input_axes: Tuple[str, ...] 
= None, ffn1_ckpt_name: str = 'ffn1', - ffn2_ckpt_name: str = 'ffn2') -> jnp.ndarray: + ffn2_ckpt_name: str = 'ffn2', + activation_type: Sequence[Union[str, Callable]] = ('gelu',), + use_bias: bool = True) -> jnp.ndarray: """ - Layernorm + GEMM1 + bias + GeLU + GEMM2 + bias + Layernorm + GEMM1 + bias + activation + GEMM2 + bias """ assert len(kernels) == 2 @@ -424,32 +113,36 @@ def layernorm_gelu_fp8_mlp(x: jnp.ndarray, assert not zero_centered_gamma, "zero_centered_gamma is not supported " \ "if layernorm_type is 'rmsnorm'" - output = _layernorm_gelu_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max, + output = _fused_layernorm_fp8_mlp(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, fp8_max, amax, scale, scale_inv, fwd_dtype, bwd_dtype, layernorm_type, zero_centered_gamma, epsilon, layernorm_input_axes, dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name, - ffn2_ckpt_name) + ffn2_ckpt_name, activation_type, use_bias) return output -@partial(jax.custom_vjp, nondiff_argnums=(11, 12, 13, 14, 15, 16, 17, 18, 19, 20)) -def _layernorm_gelu_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, +@partial(jax.custom_vjp, nondiff_argnums=(11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22)) +def _fused_layernorm_fp8_mlp(x: jnp.ndarray, gamma: jnp.ndarray, beta: jnp.ndarray, kernel_1: jnp.ndarray, kernel_2: jnp.ndarray, bias_1: jnp.ndarray, bias_2: jnp.ndarray, fp8_max: jnp.ndarray, amax: jnp.ndarray, scale: jnp.ndarray, scale_inv: jnp.ndarray, fwd_dtype: jnp.dtype, bwd_dtype: jnp.dtype, layernorm_type: str, zero_centered_gamma: bool, epsilon: float, layernorm_input_axes: Tuple[str, ...], dot_1_input_axes: Tuple[str, ...], dot_2_input_axes: Tuple[str, ...], - ffn1_ckpt_name: str, ffn2_ckpt_name: str): - output, _ = _layernorm_gelu_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, bias_1, bias_2, - fp8_max, amax, scale, scale_inv, fwd_dtype, - bwd_dtype, layernorm_type, zero_centered_gamma, - epsilon, layernorm_input_axes, dot_1_input_axes, - dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name) + ffn1_ckpt_name: str, ffn2_ckpt_name: str, + activation_type: Sequence[Union[str, Callable]], + use_bias: bool): + output, _ = _fused_layernorm_fp8_mlp_fwd_rule(x, gamma, beta, kernel_1, kernel_2, bias_1, + bias_2, fp8_max, amax, scale, scale_inv, + fwd_dtype, bwd_dtype, layernorm_type, + zero_centered_gamma, epsilon, + layernorm_input_axes, dot_1_input_axes, + dot_2_input_axes, ffn1_ckpt_name, ffn2_ckpt_name, + activation_type, use_bias) return output -def _layernorm_gelu_fp8_mlp_fwd_rule( +def _fused_layernorm_fp8_mlp_fwd_rule( x, gamma, beta, @@ -470,13 +163,16 @@ def _layernorm_gelu_fp8_mlp_fwd_rule( dot_1_input_axes, dot_2_input_axes, ffn1_ckpt_name, - ffn2_ckpt_name): + ffn2_ckpt_name, + activation_type, + use_bias): + is_gated = len(activation_type) > 1 # x should be in shape of (batch..., hidden) # Kernel_1 should be in shape of (Hidden_in, 1, Hidden_out) # Kernel_2 should be in shape of (Hidden_in, Hidden_out) assert len(kernel_1.shape) == 3 - assert kernel_1.shape[-2] == 1 + assert kernel_1.shape[-2] == len(activation_type) assert len(kernel_2.shape) == 2 x_contracting_dims = (len(x.shape) - 1,) @@ -487,7 +183,8 @@ def _layernorm_gelu_fp8_mlp_fwd_rule( # Squeeze act axis # (hidden_in, 1, hidden_out) -> (hidden_in, hidden_out) - kernel_1 = jnp.squeeze(kernel_1, axis=-2) + if not is_gated: + kernel_1 = jnp.squeeze(kernel_1, axis=-2) amax = FP8Helper.update_amax_history(amax) @@ -539,22 +236,26 @@ def _layernorm_gelu_fp8_mlp_fwd_rule( dot_1_output = fp8_dot_impl(ln_out, 
casted_kernel_1, x_scale_inv, kernel_1_scale_inv, x.dtype, (x_contracting_dims, (0,)), get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) - - bias_1_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1.shape - dot_1_output += jnp.reshape(bias_1, bias_1_shape) + if use_bias: + bias_1_shape = (1,) * (dot_1_output.ndim - bias_1.ndim) + bias_1.shape + dot_1_output += jnp.reshape(bias_1, bias_1_shape) dot_1_output = checkpoint_name(dot_1_output, ffn1_ckpt_name) gemm2_x_idx, gemm2_kernel_idx, _ = FP8Helper.get_fp8_meta_indices(1) - gelu_out_amax = amax[gemm2_x_idx, 0:1] - gelu_out_scale = scale[gemm2_x_idx] - gelu_out_scale_inv = scale_inv[gemm2_x_idx] + activation_lu_out_amax = amax[gemm2_x_idx, 0:1] + activation_lu_out_scale = scale[gemm2_x_idx] + activation_lu_out_scale_inv = scale_inv[gemm2_x_idx] + + activation_lu_fp8 = activation_fp8_dict[activation_type]["fwd"] # (batch..., hidden_in) -> (batch..., hidden) - casted_gelu_out, updated_gelu_amax = gelu_fp8(dot_1_output, gelu_out_amax, gelu_out_scale, - gelu_out_scale_inv, fwd_dtype) + casted_activation_lu_out, updated_activation_lu_amax = activation_lu_fp8(dot_1_output, + activation_lu_out_amax, activation_lu_out_scale, + activation_lu_out_scale_inv, fwd_dtype) - casted_gelu_out = with_sharding_constraint_by_logical_axes(casted_gelu_out, dot_2_input_axes) + casted_activation_lu_out = with_sharding_constraint_by_logical_axes(casted_activation_lu_out, + dot_2_input_axes) kernel_2_scale = scale[gemm2_kernel_idx] kernel_2_scale_inv = scale_inv[gemm2_kernel_idx] @@ -563,23 +264,26 @@ def _layernorm_gelu_fp8_mlp_fwd_rule( casted_kernel_2, updated_kernel_2_amax = quantize(kernel_2, fwd_dtype, kernel_2_scale) # (batch..., hidden_in) x (hidden_out, hidden_in) - dot_2_output = fp8_dot_impl(casted_gelu_out, casted_kernel_2, gelu_out_scale_inv, + dot_2_output = fp8_dot_impl(casted_activation_lu_out, casted_kernel_2, + activation_lu_out_scale_inv, kernel_2_scale_inv, x.dtype, (x_contracting_dims, (0,)), get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_FPROP)) - bias_2_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2.shape - dot_2_output += jnp.reshape(bias_2, bias_2_shape) + if use_bias: + bias_2_shape = (1,) * (dot_2_output.ndim - bias_2.ndim) + bias_2.shape + dot_2_output += jnp.reshape(bias_2, bias_2_shape) + dot_2_output = checkpoint_name(dot_2_output, ffn2_ckpt_name) - ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_gelu_out, casted_kernel_1, - casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, updated_gelu_amax, - updated_kernel_1_amax, updated_kernel_2_amax, x_contracting_dims, xt_batch_dims, - bias_1.shape, bias_2.shape) + ctx = (x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, casted_kernel_1, + casted_kernel_2, fp8_max, amax, scale, scale_inv, updated_x_amax, + updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, + x_contracting_dims, xt_batch_dims, bias_1.shape, bias_2.shape) return dot_2_output, ctx -def _layernorm_gelu_fp8_mlp_bwd_rule( +def _fused_layernorm_fp8_mlp_bwd_rule( fwd_dtype, # pylint: disable=unused-argument bwd_dtype, layernorm_type, @@ -590,13 +294,17 @@ def _layernorm_gelu_fp8_mlp_bwd_rule( dot_2_input_axes, ffn1_ckpt_name, # pylint: disable=unused-argument ffn2_ckpt_name, # pylint: disable=unused-argument + activation_type, + use_bias, ctx, grad): - x, ln_out, mu, rsigma, gamma, dot_1_output, casted_gelu_out, \ + x, ln_out, mu, rsigma, gamma, dot_1_output, casted_activation_lu_out, \ casted_kernel_1, casted_kernel_2, fp8_max, amax, scale, 
scale_inv, updated_x_amax, \ - updated_gelu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \ + updated_activation_lu_amax, updated_kernel_1_amax, updated_kernel_2_amax, \ x_contracting_dims, xt_batch_dims, bias_1_shape, bias_2_shape= ctx + is_gated = len(activation_type) > 1 + gemm2_x_idx, gemm2_kernel_idx, gemm2_grad_idx = FP8Helper.get_fp8_meta_indices(1) grad_amax = amax[gemm2_grad_idx, 0:1] @@ -606,21 +314,29 @@ def _layernorm_gelu_fp8_mlp_bwd_rule( # Since the sharding of outputs should be the same as dot_1's input grad = with_sharding_constraint_by_logical_axes(grad, dot_1_input_axes) - casted_grad, casted_grad_t, updated_grad_amax = \ - cast_transpose(grad, grad_amax, grad_scale, grad_scale_inv, bwd_dtype, - static_axis_boundary=-1, transpose_axis_boundary=-1) - - casted_gelu_out_t = transpose(casted_gelu_out, - static_axis_boundary=-1, - transpose_axis_boundary=-1) + if use_bias: + casted_grad, casted_grad_t, dbias_2, updated_grad_amax = \ + dbias_cast_transpose(grad, grad_amax, grad_scale, + grad_scale_inv, bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1) + dbias_2 = jnp.reshape(dbias_2, bias_2_shape) + else: + casted_grad, casted_grad_t, updated_grad_amax = \ + cast_transpose(grad, grad_amax, grad_scale, + grad_scale_inv, bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1) + dbias_2 = jnp.empty(bias_2_shape, grad.dtype) - dbias_2 = jnp.sum(grad, axis=(i for i in range(grad.ndim - 1))) - dbias_2 = jnp.reshape(dbias_2, bias_2_shape) + casted_activation_lu_out_t = transpose(casted_activation_lu_out, + static_axis_boundary=-1, + transpose_axis_boundary=-1) # (hidden, batch...,) x (hidden, batch...) gemm2_x_scale_inv = scale_inv[gemm2_x_idx] - wgrad_2 = fp8_dot_impl(casted_gelu_out_t, casted_grad_t, gemm2_x_scale_inv, grad_scale_inv, - grad.dtype, (xt_batch_dims, xt_batch_dims), + wgrad_2 = fp8_dot_impl(casted_activation_lu_out_t, casted_grad_t, gemm2_x_scale_inv, + grad_scale_inv, grad.dtype, (xt_batch_dims, xt_batch_dims), get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD)) # (batch..., hidden_out) x (hidden_in, hidden_out) @@ -633,36 +349,85 @@ def _layernorm_gelu_fp8_mlp_bwd_rule( gemm1_x_idx, gemm1_kernel_idx, gemm1_grad_idx = FP8Helper.get_fp8_meta_indices(0) - dgelu_amax = amax[gemm1_grad_idx, 0:1] - dgelu_scale = scale[gemm1_grad_idx] - dgelu_scale_inv = scale_inv[gemm1_grad_idx] - - casted_dgelu, casted_dgelu_t, dbias_1, updated_dgelu_amax = dgelu_dbias_cast_transpose( - dgrad_2, - dot_1_output, - dgelu_amax, - dgelu_scale, - dgelu_scale_inv, - bwd_dtype, - static_axis_boundary=-1, - transpose_axis_boundary=-1) - - dbias_1 = jnp.reshape(dbias_1, bias_1_shape) + dactivation_lu_amax = amax[gemm1_grad_idx, 0:1] + dactivation_lu_scale = scale[gemm1_grad_idx] + dactivation_lu_scale_inv = scale_inv[gemm1_grad_idx] + + dactivation_lu_cast_transpose = activation_fp8_dict[activation_type]["bwd"] + dactivation_lu = activation_dict[activation_type]["bwd"](dgrad_2, dot_1_output) + + if is_gated: + if use_bias: + casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \ + dbias_cast_transpose( + dactivation_lu, + dactivation_lu_amax, + dactivation_lu_scale, + dactivation_lu_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-2) + dbias_1 = jnp.reshape(dbias_1, bias_1_shape) + else: + casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \ + dactivation_lu_cast_transpose( + dgrad_2, + dot_1_output, + dactivation_lu_amax, + dactivation_lu_scale, + 
dactivation_lu_scale_inv, + bwd_dtype, + static_axis_boundary=-1) + dbias_1 = jnp.empty(bias_1_shape, bwd_dtype) + else: + if use_bias: + casted_dactivation_lu, casted_dactivation_lu_t, dbias_1, updated_dactivation_lu_amax = \ + dactivation_lu_cast_transpose( + dgrad_2, + dot_1_output, + dactivation_lu_amax, + dactivation_lu_scale, + dactivation_lu_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1) + dbias_1 = jnp.reshape(dbias_1, bias_1_shape) + else: + casted_dactivation_lu, casted_dactivation_lu_t, updated_dactivation_lu_amax = \ + cast_transpose( + dactivation_lu, + dactivation_lu_amax, + dactivation_lu_scale, + dactivation_lu_scale_inv, + bwd_dtype, + static_axis_boundary=-1, + transpose_axis_boundary=-1) + dbias_1 = jnp.empty(bias_1_shape, bwd_dtype) ln_out_t = transpose(ln_out, static_axis_boundary=-1, transpose_axis_boundary=-1) # (hidden, batch...) x (hidden, batch...) gemm1_x_scale_inv = scale_inv[gemm1_x_idx] - wgrad_1 = fp8_dot_impl(ln_out_t, casted_dgelu_t, gemm1_x_scale_inv, dgelu_scale_inv, grad.dtype, - (xt_batch_dims, xt_batch_dims), + xt_batch_dims_2 = xt_batch_dims if not is_gated \ + else tuple(i + 1 for i in xt_batch_dims) + wgrad_1 = fp8_dot_impl(ln_out_t, casted_dactivation_lu_t, gemm1_x_scale_inv, + dactivation_lu_scale_inv, grad.dtype, + (xt_batch_dims, xt_batch_dims_2), get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_WGRAD)) # Expand act axis to match the shape with the given kernel_1 - wgrad_1 = jnp.expand_dims(wgrad_1, axis=-2) + if not is_gated: + wgrad_1 = jnp.expand_dims(wgrad_1, axis=-2) # (batch..., hidden_out) x (hidden_in, hidden_out) + if is_gated: + x_contracting_dims = ((min(x_contracting_dims),) + tuple( + i + 1 for i in x_contracting_dims), (1,2)) + else: + x_contracting_dims = (x_contracting_dims, (1,)) kernel_1_scale_inv = scale_inv[gemm1_kernel_idx] - dgrad_1 = fp8_dot_impl(casted_dgelu, casted_kernel_1, dgelu_scale_inv, kernel_1_scale_inv, - grad.dtype, (x_contracting_dims, (1,)), + dgrad_1 = fp8_dot_impl(casted_dactivation_lu, casted_kernel_1, + dactivation_lu_scale_inv, kernel_1_scale_inv, + grad.dtype, x_contracting_dims, get_precision_of_fp8_dot(FP8Helper.FP8_2X_ACC_DGRAD)) dgrad_1 = with_sharding_constraint_by_logical_axes(dgrad_1, layernorm_input_axes) @@ -683,15 +448,15 @@ def _layernorm_gelu_fp8_mlp_bwd_rule( amax = amax.at[gemm1_x_idx, 0].set(updated_x_amax[0]) amax = amax.at[gemm1_kernel_idx, 0].set(updated_kernel_1_amax[0]) - amax = amax.at[gemm1_grad_idx, 0].set(updated_dgelu_amax[0]) - amax = amax.at[gemm2_x_idx, 0].set(updated_gelu_amax[0]) + amax = amax.at[gemm1_grad_idx, 0].set(updated_dactivation_lu_amax[0]) + amax = amax.at[gemm2_x_idx, 0].set(updated_activation_lu_amax[0]) amax = amax.at[gemm2_kernel_idx, 0].set(updated_kernel_2_amax) amax = amax.at[gemm2_grad_idx, 0].set(updated_grad_amax[0]) scale, scale_inv = FP8Helper.update_fp8_scale(fp8_max, amax, scale) - return dx, dgamma, dbeta, wgrad_1, wgrad_2, dbias_1, dbias_2, \ fp8_max, amax, scale, scale_inv -_layernorm_gelu_fp8_mlp.defvjp(_layernorm_gelu_fp8_mlp_fwd_rule, _layernorm_gelu_fp8_mlp_bwd_rule) +_fused_layernorm_fp8_mlp.defvjp(_fused_layernorm_fp8_mlp_fwd_rule, + _fused_layernorm_fp8_mlp_bwd_rule)
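
A minimal usage sketch of the unified entry point, condensed from the new test in tests/jax/test_custom_call_compute.py; the shapes, the (m, n, k) sizes, and the FP8 metadata follow that test, and running it assumes an FP8-capable GPU:

    import jax
    import jax.numpy as jnp

    from transformer_engine.jax.fp8 import FP8Helper, FP8MetaPackage
    from transformer_engine.jax.mlp import fused_layernorm_fp8_mlp

    m, n, k = 256, 512, 128
    activation_type = ('gelu', 'linear')    # GeGLU; ('gelu',) selects the plain GeLU path
    use_bias = True

    keys = jax.random.split(jax.random.PRNGKey(0), 6)
    x        = jax.random.normal(keys[0], (m, k), jnp.bfloat16)
    kernel_1 = jax.random.normal(keys[1], (k, len(activation_type), n), jnp.bfloat16)
    kernel_2 = jax.random.normal(keys[2], (n, k), jnp.bfloat16)
    bias_1   = jax.random.normal(keys[3], (len(activation_type), n), jnp.bfloat16)
    bias_2   = jax.random.normal(keys[4], (k,), jnp.bfloat16)
    ln_scale = jax.random.normal(keys[5], (k,), jnp.bfloat16)

    # FP8 metadata for the two GEMMs, exactly as the test builds it.
    fp8_max   = FP8Helper.generate_fp8_max_array(FP8Helper.NUM_META_PER_GEMM * 2)
    amax      = jnp.zeros((FP8Helper.NUM_META_PER_GEMM * 2, FP8Helper.AMAX_HISTORY_LEN),
                          jnp.float32)
    scale     = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
    scale_inv = jnp.ones((FP8Helper.NUM_META_PER_GEMM * 2, 1), jnp.float32)
    fp8_meta_pkg = FP8MetaPackage(2, fp8_max, amax, scale, scale_inv)

    # out = activation((LN(x) @ kernel_1) + bias_1) @ kernel_2 + bias_2
    out = fused_layernorm_fp8_mlp(x, ln_scale, None, [kernel_1, kernel_2], [bias_1, bias_2],
                                  fp8_meta_pkg, "rmsnorm",
                                  activation_type=activation_type, use_bias=use_bias)

Passing activation_type=('gelu',) selects the former layernorm_gelu path and ('gelu', 'linear') the former GeGLU path; with use_bias=False the test passes zero-length bias arrays of shape (0,) rather than omitting the bias arguments.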
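The activation_lu wrapper added to mlp.py keys a {activation tuple -> fwd/bwd kernel} table and threads the tuple through jax.custom_vjp as a non-differentiable argument. Below is a simplified, self-contained analogue of that pattern; act_lu, _ACT, _geglu_fwd and friends are illustrative stand-ins using plain JAX in place of the TE CUDA kernels:

    from functools import partial
    import jax
    import jax.numpy as jnp

    def _gelu_fwd(x):
        return jax.nn.gelu(x)

    def _dgelu(g, x):
        # VJP of gelu w.r.t. x, evaluated at cotangent g
        return jax.vjp(_gelu_fwd, x)[1](g)[0]

    def _geglu_fwd(x):                      # x: (..., 2, hidden) -> (..., hidden)
        return jax.nn.gelu(x[..., 0, :]) * x[..., 1, :]

    def _dgeglu(g, x):
        return jax.vjp(_geglu_fwd, x)[1](g)[0]

    _ACT = {('gelu',): (_gelu_fwd, _dgelu),
            ('gelu', 'linear'): (_geglu_fwd, _dgeglu)}

    @partial(jax.custom_vjp, nondiff_argnums=(1,))
    def act_lu(x, activation_type):
        return _ACT[activation_type][0](x)

    def _act_lu_fwd(x, activation_type):
        return _ACT[activation_type][0](x), (x,)

    def _act_lu_bwd(activation_type, ctx, g):
        (x,) = ctx
        return (_ACT[activation_type][1](g, x).reshape(x.shape),)

    act_lu.defvjp(_act_lu_fwd, _act_lu_bwd)

    y = act_lu(jnp.ones((4, 2, 8), jnp.float32), ('gelu', 'linear'))   # GeGLU path

Because the activation tuple sits in nondiff_argnums it must stay hashable, which is why the Flax module sorts and tuples the lower-cased activation names before dispatching to the fused path.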
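The new te_dbias_cast_transpose primitive fuses what the backward rule previously assembled by hand: an FP8 cast of the incoming gradient, its transpose, and dbias as a reduction over the batch dimensions. A plain-JAX reference of those semantics (a sketch only, assuming a JAX build with float8 dtypes; dbias_cast_transpose_ref is an illustrative name, not the TE kernel):

    import jax.numpy as jnp

    def dbias_cast_transpose_ref(dz, scale, out_dtype=jnp.float8_e5m2):
        # dz: bf16/fp16/fp32 gradient of shape (batch..., hidden)
        amax = jnp.max(jnp.abs(dz)).astype(jnp.float32)            # updated amax
        casted = (dz.astype(jnp.float32) * scale).astype(out_dtype)
        casted_t = jnp.reshape(casted, (-1, dz.shape[-1])).T       # (hidden, batch...)
        dbias = jnp.sum(dz, axis=tuple(range(dz.ndim - 1)))        # (hidden,)
        return casted, casted_t, dbias, amax

In the patch this fused path is only taken when use_bias is True; otherwise the backward rule keeps the plain cast_transpose and returns an empty placeholder of shape bias_2_shape (or bias_1_shape) for the bias gradient.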
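In flax/module.py the fused path is now gated by a small normalization of self.activations instead of the removed is_geglu/is_gelu helpers. Condensed into a standalone sketch (classify is an illustrative name; treating callables as "not implemented" mirrors the module's fallback to the unfused path):

    gated_act_pool = [('gelu', 'linear')]   # ('linear', 'silu') coming
    act_pool = [('gelu',)]                  # ('silu',) coming

    def classify(activations):
        if not all(isinstance(a, str) for a in activations):
            return None, False, False       # callables fall back to the unfused path
        names = tuple(sorted(a.lower() for a in activations))
        is_gated = names in gated_act_pool
        is_implemented = names in (gated_act_pool + act_pool)
        return names, is_gated, is_implemented

    print(classify(('linear', 'gelu')))     # (('gelu', 'linear'), True, True)

The fused fused_layernorm_fp8_mlp call is then used only when FP8 is enabled, layernorm output is not returned, the activation tuple is implemented, the intermediate dropout rate is effectively zero, and low-rank adaptation is disabled.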